github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/server_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"golang.org/x/exp/slog"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"runtime"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/metacubex/quic-go"
    17  	mockquic "github.com/metacubex/quic-go/internal/mocks/quic"
    18  	"github.com/metacubex/quic-go/internal/protocol"
    19  	"github.com/metacubex/quic-go/internal/testdata"
    20  	"github.com/metacubex/quic-go/quicvarint"
    21  
    22  	"github.com/quic-go/qpack"
    23  	"go.uber.org/mock/gomock"
    24  
    25  	. "github.com/onsi/ginkgo/v2"
    26  	. "github.com/onsi/gomega"
    27  	gmtypes "github.com/onsi/gomega/types"
    28  )
    29  
    30  type mockAddr struct{ addr string }
    31  
    32  func (ma *mockAddr) Network() string { return "udp" }
    33  func (ma *mockAddr) String() string  { return ma.addr }
    34  
    35  type mockAddrListener struct {
    36  	*MockQUICEarlyListener
    37  	addr *mockAddr
    38  }
    39  
    40  func (m *mockAddrListener) Addr() net.Addr {
    41  	_ = m.MockQUICEarlyListener.Addr()
    42  	return m.addr
    43  }
    44  
    45  func newMockAddrListener(addr string) *mockAddrListener {
    46  	return &mockAddrListener{
    47  		MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl),
    48  		addr:                  &mockAddr{addr: addr},
    49  	}
    50  }
    51  
    52  type noPortListener struct {
    53  	*mockAddrListener
    54  }
    55  
    56  func (m *noPortListener) Addr() net.Addr {
    57  	_ = m.mockAddrListener.Addr()
    58  	return &net.UnixAddr{
    59  		Net:  "unix",
    60  		Name: "/tmp/quic.sock",
    61  	}
    62  }
    63  
    64  var _ = Describe("Server", func() {
    65  	var (
    66  		s                  *Server
    67  		origQuicListenAddr = quicListenAddr
    68  	)
    69  	type testConnContextKey string
    70  
    71  	BeforeEach(func() {
    72  		s = &Server{
    73  			TLSConfig: testdata.GetTLSConfig(),
    74  			ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
    75  				return context.WithValue(ctx, testConnContextKey("test"), c)
    76  			},
    77  		}
    78  		origQuicListenAddr = quicListenAddr
    79  	})
    80  
    81  	AfterEach(func() {
    82  		quicListenAddr = origQuicListenAddr
    83  	})
    84  
    85  	Context("handling requests", func() {
    86  		var (
    87  			qpackDecoder       *qpack.Decoder
    88  			str                *mockquic.MockStream
    89  			conn               *connection
    90  			exampleGetRequest  *http.Request
    91  			examplePostRequest *http.Request
    92  		)
    93  		reqContext, reqContextCancel := context.WithCancel(context.Background())
    94  
    95  		decodeHeader := func(str io.Reader) map[string][]string {
    96  			fields := make(map[string][]string)
    97  			decoder := qpack.NewDecoder(nil)
    98  
    99  			fp := frameParser{r: str}
   100  			frame, err := fp.ParseNext()
   101  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   102  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   103  			headersFrame := frame.(*headersFrame)
   104  			data := make([]byte, headersFrame.Length)
   105  			_, err = io.ReadFull(str, data)
   106  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   107  			hfs, err := decoder.DecodeFull(data)
   108  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   109  			for _, p := range hfs {
   110  				fields[p.Name] = append(fields[p.Name], p.Value)
   111  			}
   112  			return fields
   113  		}
   114  
   115  		encodeRequest := func(req *http.Request) []byte {
   116  			buf := &bytes.Buffer{}
   117  			str := mockquic.NewMockStream(mockCtrl)
   118  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   119  			rw := newRequestWriter()
   120  			Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
   121  			return buf.Bytes()
   122  		}
   123  
   124  		setRequest := func(data []byte) {
   125  			buf := bytes.NewBuffer(data)
   126  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   127  				if buf.Len() == 0 {
   128  					return 0, io.EOF
   129  				}
   130  				return buf.Read(p)
   131  			}).AnyTimes()
   132  		}
   133  
   134  		BeforeEach(func() {
   135  			var err error
   136  			exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
   137  			Expect(err).ToNot(HaveOccurred())
   138  			examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
   139  			Expect(err).ToNot(HaveOccurred())
   140  
   141  			qpackDecoder = qpack.NewDecoder(nil)
   142  			str = mockquic.NewMockStream(mockCtrl)
   143  			str.EXPECT().Context().Return(reqContext).AnyTimes()
   144  			str.EXPECT().StreamID().AnyTimes()
   145  			qconn := mockquic.NewMockEarlyConnection(mockCtrl)
   146  			addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   147  			qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   148  			qconn.EXPECT().LocalAddr().AnyTimes()
   149  			qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   150  			qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
   151  			conn = newConnection(qconn, false, protocol.PerspectiveServer, nil)
   152  		})
   153  
   154  		It("calls the HTTP handler function", func() {
   155  			requestChan := make(chan *http.Request, 1)
   156  			s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
   157  				requestChan <- r
   158  			})
   159  
   160  			setRequest(encodeRequest(exampleGetRequest))
   161  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   162  				return len(p), nil
   163  			}).AnyTimes()
   164  			str.EXPECT().CancelRead(gomock.Any())
   165  			str.EXPECT().Close()
   166  
   167  			s.handleRequest(conn, str, nil, qpackDecoder)
   168  			var req *http.Request
   169  			Eventually(requestChan).Should(Receive(&req))
   170  			Expect(req.Host).To(Equal("www.example.com"))
   171  			Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
   172  			Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
   173  			Expect(req.Context().Value(testConnContextKey("test"))).To(Equal(conn.Connection))
   174  		})
   175  
   176  		It("returns 200 with an empty handler", func() {
   177  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   178  
   179  			responseBuf := &bytes.Buffer{}
   180  			setRequest(encodeRequest(exampleGetRequest))
   181  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   182  			str.EXPECT().CancelRead(gomock.Any())
   183  			str.EXPECT().Close()
   184  
   185  			s.handleRequest(conn, str, nil, qpackDecoder)
   186  			hfs := decodeHeader(responseBuf)
   187  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   188  		})
   189  
   190  		It("sets Content-Length when the handler doesn't flush to the client", func() {
   191  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   192  				w.Write([]byte("foobar"))
   193  			})
   194  
   195  			responseBuf := &bytes.Buffer{}
   196  			setRequest(encodeRequest(exampleGetRequest))
   197  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   198  			str.EXPECT().CancelRead(gomock.Any())
   199  			str.EXPECT().Close()
   200  
   201  			s.handleRequest(conn, str, nil, qpackDecoder)
   202  			hfs := decodeHeader(responseBuf)
   203  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   204  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
   205  			// status, content-length, date, content-type
   206  			Expect(hfs).To(HaveLen(4))
   207  		})
   208  
   209  		It("sets Content-Type when WriteHeader is called but response is not flushed", func() {
   210  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   211  				w.WriteHeader(http.StatusNotFound)
   212  				w.Write([]byte("<html></html>"))
   213  			})
   214  
   215  			responseBuf := &bytes.Buffer{}
   216  			setRequest(encodeRequest(exampleGetRequest))
   217  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   218  			str.EXPECT().CancelRead(gomock.Any())
   219  			str.EXPECT().Close()
   220  
   221  			s.handleRequest(conn, str, nil, qpackDecoder)
   222  			hfs := decodeHeader(responseBuf)
   223  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"404"}))
   224  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   225  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   226  		})
   227  
   228  		It("not sets Content-Length when the handler flushes to the client", func() {
   229  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   230  				w.Write([]byte("foobar"))
   231  				// force flush
   232  				w.(http.Flusher).Flush()
   233  			})
   234  
   235  			responseBuf := &bytes.Buffer{}
   236  			setRequest(encodeRequest(exampleGetRequest))
   237  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   238  			str.EXPECT().CancelRead(gomock.Any())
   239  			str.EXPECT().Close()
   240  
   241  			s.handleRequest(conn, str, nil, qpackDecoder)
   242  			hfs := decodeHeader(responseBuf)
   243  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   244  			// status, date, content-type
   245  			Expect(hfs).To(HaveLen(3))
   246  		})
   247  
   248  		It("ignores calls to Write for responses to HEAD requests", func() {
   249  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   250  				w.Write([]byte("foobar"))
   251  			})
   252  
   253  			headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
   254  			Expect(err).ToNot(HaveOccurred())
   255  			responseBuf := &bytes.Buffer{}
   256  			setRequest(encodeRequest(headRequest))
   257  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   258  			str.EXPECT().CancelRead(gomock.Any())
   259  			str.EXPECT().Close()
   260  
   261  			s.handleRequest(conn, str, nil, qpackDecoder)
   262  			hfs := decodeHeader(responseBuf)
   263  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   264  			Expect(responseBuf.Bytes()).To(BeEmpty())
   265  		})
   266  
   267  		It("response to HEAD request should also do content sniffing", func() {
   268  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   269  				w.Write([]byte("<html></html>"))
   270  			})
   271  
   272  			headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
   273  			Expect(err).ToNot(HaveOccurred())
   274  			responseBuf := &bytes.Buffer{}
   275  			setRequest(encodeRequest(headRequest))
   276  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   277  			str.EXPECT().CancelRead(gomock.Any())
   278  			str.EXPECT().Close()
   279  
   280  			s.handleRequest(conn, str, nil, qpackDecoder)
   281  			hfs := decodeHeader(responseBuf)
   282  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   283  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   284  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   285  		})
   286  
   287  		It("handles an aborting handler", func() {
   288  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   289  				panic(http.ErrAbortHandler)
   290  			})
   291  
   292  			responseBuf := &bytes.Buffer{}
   293  			setRequest(encodeRequest(exampleGetRequest))
   294  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   295  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
   296  			str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
   297  
   298  			s.handleRequest(conn, str, nil, qpackDecoder)
   299  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   300  		})
   301  
   302  		It("handles a panicking handler", func() {
   303  			var logBuf bytes.Buffer
   304  			s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil))
   305  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   306  				panic("foobar")
   307  			})
   308  
   309  			responseBuf := &bytes.Buffer{}
   310  			setRequest(encodeRequest(exampleGetRequest))
   311  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   312  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
   313  			str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
   314  
   315  			s.handleRequest(conn, str, nil, qpackDecoder)
   316  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   317  			Expect(logBuf.String()).To(ContainSubstring("http: panic serving"))
   318  			Expect(logBuf.String()).To(ContainSubstring("foobar"))
   319  		})
   320  
   321  		Context("hijacking bidirectional streams", func() {
   322  			var conn *mockquic.MockEarlyConnection
   323  			testDone := make(chan struct{})
   324  
   325  			BeforeEach(func() {
   326  				testDone = make(chan struct{})
   327  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   328  				controlStr := mockquic.NewMockStream(mockCtrl)
   329  				controlStr.EXPECT().Write(gomock.Any())
   330  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   331  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   332  				conn.EXPECT().LocalAddr().AnyTimes()
   333  			})
   334  
   335  			AfterEach(func() { testDone <- struct{}{} })
   336  
   337  			It("hijacks a bidirectional stream of unknown frame type", func() {
   338  				id := quic.ConnectionTracingID(1337)
   339  				frameTypeChan := make(chan FrameType, 1)
   340  				s.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   341  					defer GinkgoRecover()
   342  					Expect(e).ToNot(HaveOccurred())
   343  					Expect(connTracingID).To(Equal(id))
   344  					frameTypeChan <- ft
   345  					return true, nil
   346  				}
   347  
   348  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   349  				unknownStr := mockquic.NewMockStream(mockCtrl)
   350  				unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
   351  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   352  				unknownStr.EXPECT().StreamID().AnyTimes()
   353  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   354  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   355  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   356  					<-testDone
   357  					return nil, errors.New("test done")
   358  				})
   359  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
   360  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   361  				s.handleConn(conn)
   362  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   363  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   364  			})
   365  
   366  			It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
   367  				frameTypeChan := make(chan FrameType, 1)
   368  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   369  					Expect(e).ToNot(HaveOccurred())
   370  					frameTypeChan <- ft
   371  					return false, nil
   372  				}
   373  
   374  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   375  				unknownStr := mockquic.NewMockStream(mockCtrl)
   376  				unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
   377  				unknownStr.EXPECT().StreamID().AnyTimes()
   378  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   379  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   380  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   381  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   382  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   383  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   384  					<-testDone
   385  					return nil, errors.New("test done")
   386  				})
   387  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   388  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   389  				s.handleConn(conn)
   390  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   391  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   392  			})
   393  
   394  			It("cancels writing when hijacker returned error", func() {
   395  				frameTypeChan := make(chan FrameType, 1)
   396  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   397  					Expect(e).ToNot(HaveOccurred())
   398  					frameTypeChan <- ft
   399  					return false, errors.New("error in hijacker")
   400  				}
   401  
   402  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   403  				unknownStr := mockquic.NewMockStream(mockCtrl)
   404  				unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
   405  				unknownStr.EXPECT().StreamID().AnyTimes()
   406  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   407  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   408  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   409  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   410  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   411  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   412  					<-testDone
   413  					return nil, errors.New("test done")
   414  				})
   415  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   416  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   417  				s.handleConn(conn)
   418  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   419  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   420  			})
   421  
   422  			It("handles errors that occur when reading the stream type", func() {
   423  				const strID = protocol.StreamID(1234 * 4)
   424  				testErr := errors.New("test error")
   425  				done := make(chan struct{})
   426  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) {
   427  					defer close(done)
   428  					Expect(ft).To(BeZero())
   429  					Expect(str.StreamID()).To(Equal(strID))
   430  					Expect(err).To(MatchError(testErr))
   431  					return true, nil
   432  				}
   433  				unknownStr := mockquic.NewMockStream(mockCtrl)
   434  				unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
   435  				unknownStr.EXPECT().StreamID().Return(strID).AnyTimes()
   436  				unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   437  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   438  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   439  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   440  					<-testDone
   441  					return nil, errors.New("test done")
   442  				})
   443  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   444  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   445  				s.handleConn(conn)
   446  				Eventually(done).Should(BeClosed())
   447  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   448  			})
   449  		})
   450  
   451  		Context("hijacking unidirectional streams", func() {
   452  			var conn *mockquic.MockEarlyConnection
   453  			testDone := make(chan struct{})
   454  
   455  			BeforeEach(func() {
   456  				testDone = make(chan struct{})
   457  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   458  				controlStr := mockquic.NewMockStream(mockCtrl)
   459  				controlStr.EXPECT().Write(gomock.Any())
   460  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   461  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   462  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   463  				conn.EXPECT().LocalAddr().AnyTimes()
   464  			})
   465  
   466  			AfterEach(func() { testDone <- struct{}{} })
   467  
   468  			It("hijacks an unidirectional stream of unknown stream type", func() {
   469  				id := quic.ConnectionTracingID(42)
   470  				streamTypeChan := make(chan StreamType, 1)
   471  				s.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
   472  					Expect(err).ToNot(HaveOccurred())
   473  					Expect(connTracingID).To(Equal(id))
   474  					streamTypeChan <- st
   475  					return true
   476  				}
   477  
   478  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   479  				unknownStr := mockquic.NewMockStream(mockCtrl)
   480  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   481  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   482  					return unknownStr, nil
   483  				})
   484  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   485  					<-testDone
   486  					return nil, errors.New("test done")
   487  				})
   488  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
   489  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   490  				s.handleConn(conn)
   491  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   492  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   493  			})
   494  
   495  			It("handles errors that occur when reading the stream type", func() {
   496  				testErr := errors.New("test error")
   497  				done := make(chan struct{})
   498  				unknownStr := mockquic.NewMockStream(mockCtrl)
   499  				s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
   500  					defer close(done)
   501  					Expect(st).To(BeZero())
   502  					Expect(str).To(Equal(unknownStr))
   503  					Expect(err).To(MatchError(testErr))
   504  					return true
   505  				}
   506  
   507  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
   508  				conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   509  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   510  					<-testDone
   511  					return nil, errors.New("test done")
   512  				})
   513  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   514  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   515  				s.handleConn(conn)
   516  				Eventually(done).Should(BeClosed())
   517  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   518  			})
   519  
   520  			It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   521  				streamTypeChan := make(chan StreamType, 1)
   522  				s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
   523  					Expect(err).ToNot(HaveOccurred())
   524  					streamTypeChan <- st
   525  					return false
   526  				}
   527  
   528  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   529  				unknownStr := mockquic.NewMockStream(mockCtrl)
   530  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   531  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   532  
   533  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   534  					return unknownStr, nil
   535  				})
   536  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   537  					<-testDone
   538  					return nil, errors.New("test done")
   539  				})
   540  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   541  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   542  				s.handleConn(conn)
   543  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   544  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   545  			})
   546  		})
   547  
   548  		Context("stream- and connection-level errors", func() {
   549  			var conn *mockquic.MockEarlyConnection
   550  			testDone := make(chan struct{})
   551  
   552  			BeforeEach(func() {
   553  				testDone = make(chan struct{})
   554  				addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   555  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   556  				controlStr := mockquic.NewMockStream(mockCtrl)
   557  				controlStr.EXPECT().Write(gomock.Any())
   558  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   559  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   560  					<-testDone
   561  					return nil, errors.New("test done")
   562  				})
   563  				conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
   564  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   565  				conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   566  				conn.EXPECT().LocalAddr().AnyTimes()
   567  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   568  				conn.EXPECT().Context().Return(context.Background()).AnyTimes()
   569  			})
   570  
   571  			AfterEach(func() { testDone <- struct{}{} })
   572  
   573  			It("cancels reading when client sends a body in GET request", func() {
   574  				var handlerCalled bool
   575  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   576  					handlerCalled = true
   577  				})
   578  
   579  				requestData := encodeRequest(exampleGetRequest)
   580  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   581  				b = append(b, []byte("foobar")...)
   582  				responseBuf := &bytes.Buffer{}
   583  				setRequest(append(requestData, b...))
   584  				done := make(chan struct{})
   585  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   586  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   587  				str.EXPECT().Close().Do(func() error { close(done); return nil })
   588  
   589  				s.handleConn(conn)
   590  				Eventually(done).Should(BeClosed())
   591  				hfs := decodeHeader(responseBuf)
   592  				Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   593  				Expect(handlerCalled).To(BeTrue())
   594  			})
   595  
   596  			It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
   597  				handlerCalled := make(chan struct{})
   598  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   599  					defer close(handlerCalled)
   600  					w.(HTTPStreamer).HTTPStream()
   601  					str.Write([]byte("foobar"))
   602  				})
   603  
   604  				requestData := encodeRequest(exampleGetRequest)
   605  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   606  				b = append(b, []byte("foobar")...)
   607  				setRequest(append(requestData, b...))
   608  				var buf bytes.Buffer
   609  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   610  
   611  				s.handleConn(conn)
   612  				Eventually(handlerCalled).Should(BeClosed())
   613  
   614  				// The buffer is expected to contain:
   615  				// 1. The response header (in a HEADERS frame)
   616  				// 2. the "foobar" (unframed)
   617  				fp := frameParser{r: &buf}
   618  				frame, err := fp.ParseNext()
   619  				Expect(err).ToNot(HaveOccurred())
   620  				Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
   621  				df := frame.(*headersFrame)
   622  				data := make([]byte, df.Length)
   623  				_, err = io.ReadFull(&buf, data)
   624  				Expect(err).ToNot(HaveOccurred())
   625  				hdrs, err := qpackDecoder.DecodeFull(data)
   626  				Expect(err).ToNot(HaveOccurred())
   627  				Expect(hdrs).To(ContainElement(qpack.HeaderField{Name: ":status", Value: "200"}))
   628  				Expect(buf.Bytes()).To(Equal([]byte("foobar")))
   629  			})
   630  
   631  			It("errors when the client sends a too large header frame", func() {
   632  				s.MaxHeaderBytes = 20
   633  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   634  					Fail("Handler should not be called.")
   635  				})
   636  
   637  				requestData := encodeRequest(exampleGetRequest)
   638  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   639  				b = append(b, []byte("foobar")...)
   640  				responseBuf := &bytes.Buffer{}
   641  				setRequest(append(requestData, b...))
   642  				done := make(chan struct{})
   643  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   644  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
   645  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   646  
   647  				s.handleConn(conn)
   648  				Eventually(done).Should(BeClosed())
   649  			})
   650  
   651  			It("handles a request for which the client immediately resets the stream", func() {
   652  				handlerCalled := make(chan struct{})
   653  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   654  					close(handlerCalled)
   655  				})
   656  
   657  				testErr := errors.New("stream reset")
   658  				done := make(chan struct{})
   659  				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
   660  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   661  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
   662  
   663  				s.handleConn(conn)
   664  				Consistently(handlerCalled).ShouldNot(BeClosed())
   665  			})
   666  
   667  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   668  				handlerCalled := make(chan struct{})
   669  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   670  					close(handlerCalled)
   671  				})
   672  
   673  				b := (&dataFrame{}).Append(nil)
   674  				setRequest(b)
   675  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   676  					return len(p), nil
   677  				}).AnyTimes()
   678  
   679  				done := make(chan struct{})
   680  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   681  					close(done)
   682  					return nil
   683  				})
   684  				s.handleConn(conn)
   685  				Eventually(done).Should(BeClosed())
   686  			})
   687  
   688  			It("rejects a request that has too large request headers", func() {
   689  				handlerCalled := make(chan struct{})
   690  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   691  					close(handlerCalled)
   692  				})
   693  
   694  				// use 2*DefaultMaxHeaderBytes here. qpack will compress the request,
   695  				// but the request will still end up larger than DefaultMaxHeaderBytes.
   696  				url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
   697  				req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
   698  				Expect(err).ToNot(HaveOccurred())
   699  				setRequest(encodeRequest(req))
   700  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   701  					return len(p), nil
   702  				}).AnyTimes()
   703  				done := make(chan struct{})
   704  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
   705  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   706  
   707  				s.handleConn(conn)
   708  				Eventually(done).Should(BeClosed())
   709  			})
   710  		})
   711  
   712  		It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
   713  			handlerCalled := make(chan struct{})
   714  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   715  				r.Body = struct {
   716  					io.Reader
   717  					io.Closer
   718  				}{}
   719  				close(handlerCalled)
   720  			})
   721  
   722  			setRequest(encodeRequest(examplePostRequest))
   723  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   724  				return len(p), nil
   725  			}).AnyTimes()
   726  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   727  			str.EXPECT().Close()
   728  
   729  			s.handleRequest(conn, str, nil, qpackDecoder)
   730  			Eventually(handlerCalled).Should(BeClosed())
   731  		})
   732  
   733  		It("cancels the request context when the stream is closed", func() {
   734  			handlerCalled := make(chan struct{})
   735  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   736  				defer GinkgoRecover()
   737  				Expect(r.Context().Done()).To(BeClosed())
   738  				Expect(r.Context().Err()).To(MatchError(context.Canceled))
   739  				close(handlerCalled)
   740  			})
   741  			setRequest(encodeRequest(examplePostRequest))
   742  
   743  			reqContextCancel()
   744  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   745  				return len(p), nil
   746  			}).AnyTimes()
   747  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   748  			str.EXPECT().Close()
   749  
   750  			s.handleRequest(conn, str, nil, qpackDecoder)
   751  			Eventually(handlerCalled).Should(BeClosed())
   752  		})
   753  	})
   754  
   755  	Context("setting http headers", func() {
   756  		BeforeEach(func() {
   757  			s.QUICConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}}
   758  		})
   759  
   760  		var ln1 QUICEarlyListener
   761  		var ln2 QUICEarlyListener
   762  		expected := http.Header{
   763  			"Alt-Svc": {`h3=":443"; ma=2592000`},
   764  		}
   765  
   766  		addListener := func(addr string, ln *QUICEarlyListener) {
   767  			mln := newMockAddrListener(addr)
   768  			mln.EXPECT().Addr()
   769  			*ln = mln
   770  			s.addListener(ln)
   771  		}
   772  
   773  		removeListener := func(ln *QUICEarlyListener) {
   774  			s.removeListener(ln)
   775  		}
   776  
   777  		checkSetHeaders := func(expected gmtypes.GomegaMatcher) {
   778  			hdr := http.Header{}
   779  			Expect(s.SetQUICHeaders(hdr)).To(Succeed())
   780  			Expect(hdr).To(expected)
   781  		}
   782  
   783  		checkSetHeaderError := func() {
   784  			hdr := http.Header{}
   785  			Expect(s.SetQUICHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
   786  		}
   787  
   788  		It("sets proper headers with numeric port", func() {
   789  			addListener(":443", &ln1)
   790  			checkSetHeaders(Equal(expected))
   791  			removeListener(&ln1)
   792  			checkSetHeaderError()
   793  		})
   794  
   795  		It("sets proper headers with full addr", func() {
   796  			addListener("127.0.0.1:443", &ln1)
   797  			checkSetHeaders(Equal(expected))
   798  			removeListener(&ln1)
   799  			checkSetHeaderError()
   800  		})
   801  
   802  		It("sets proper headers with string port", func() {
   803  			addListener(":https", &ln1)
   804  			checkSetHeaders(Equal(expected))
   805  			removeListener(&ln1)
   806  			checkSetHeaderError()
   807  		})
   808  
   809  		It("works multiple times", func() {
   810  			addListener(":https", &ln1)
   811  			checkSetHeaders(Equal(expected))
   812  			checkSetHeaders(Equal(expected))
   813  			removeListener(&ln1)
   814  			checkSetHeaderError()
   815  		})
   816  
   817  		It("works if the quic.Config sets QUIC versions", func() {
   818  			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version2}
   819  			addListener(":443", &ln1)
   820  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   821  			removeListener(&ln1)
   822  			checkSetHeaderError()
   823  		})
   824  
   825  		It("uses s.Port if set to a non-zero value", func() {
   826  			s.Port = 8443
   827  			addListener(":443", &ln1)
   828  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
   829  			removeListener(&ln1)
   830  			checkSetHeaderError()
   831  		})
   832  
   833  		It("uses s.Addr if listeners don't have ports available", func() {
   834  			s.Addr = ":443"
   835  			var logBuf bytes.Buffer
   836  			s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil))
   837  			mln := &noPortListener{newMockAddrListener("")}
   838  			mln.EXPECT().Addr()
   839  			ln1 = mln
   840  			s.addListener(&ln1)
   841  			checkSetHeaders(Equal(expected))
   842  			s.removeListener(&ln1)
   843  			checkSetHeaderError()
   844  			Expect(logBuf.String()).To(ContainSubstring("Unable to extract port from listener, will not be announced using SetQUICHeaders"))
   845  		})
   846  
   847  		It("properly announces multiple listeners", func() {
   848  			addListener(":443", &ln1)
   849  			addListener(":8443", &ln2)
   850  			checkSetHeaders(Or(
   851  				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
   852  				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
   853  			))
   854  			removeListener(&ln1)
   855  			removeListener(&ln2)
   856  			checkSetHeaderError()
   857  		})
   858  
   859  		It("doesn't duplicate Alt-Svc values", func() {
   860  			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version1}
   861  			addListener(":443", &ln1)
   862  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   863  			removeListener(&ln1)
   864  			checkSetHeaderError()
   865  		})
   866  	})
   867  
   868  	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
   869  		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
   870  	})
   871  
   872  	It("should nop-Close() when s.server is nil", func() {
   873  		Expect((&Server{}).Close()).To(Succeed())
   874  	})
   875  
   876  	It("errors when ListenAndServeTLS is called after Close", func() {
   877  		serv := &Server{}
   878  		Expect(serv.Close()).To(Succeed())
   879  		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
   880  	})
   881  
   882  	It("handles concurrent Serve and Close", func() {
   883  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   884  		Expect(err).ToNot(HaveOccurred())
   885  		c, err := net.ListenUDP("udp", addr)
   886  		Expect(err).ToNot(HaveOccurred())
   887  		done := make(chan struct{})
   888  		go func() {
   889  			defer GinkgoRecover()
   890  			defer close(done)
   891  			s.Serve(c)
   892  		}()
   893  		runtime.Gosched()
   894  		s.Close()
   895  		Eventually(done).Should(BeClosed())
   896  	})
   897  
   898  	Context("ConfigureTLSConfig", func() {
   899  		It("advertises v1 by default", func() {
   900  			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
   901  			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}})
   902  			Expect(err).ToNot(HaveOccurred())
   903  			defer ln.Close()
   904  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   905  			Expect(err).ToNot(HaveOccurred())
   906  			defer c.CloseWithError(0, "")
   907  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   908  		})
   909  
   910  		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
   911  			var receivedConf *tls.Config
   912  			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
   913  				receivedConf = tlsConf
   914  				return nil, errors.New("listen err")
   915  			}
   916  			Expect(s.ListenAndServe()).To(HaveOccurred())
   917  			Expect(receivedConf).ToNot(BeNil())
   918  		})
   919  
   920  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
   921  			tlsConf := &tls.Config{
   922  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
   923  					c := testdata.GetTLSConfig()
   924  					c.NextProtos = []string{"foo", "bar"}
   925  					return c, nil
   926  				},
   927  			}
   928  
   929  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   930  			Expect(err).ToNot(HaveOccurred())
   931  			defer ln.Close()
   932  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   933  			Expect(err).ToNot(HaveOccurred())
   934  			defer c.CloseWithError(0, "")
   935  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   936  		})
   937  
   938  		It("works if GetConfigForClient returns a nil tls.Config", func() {
   939  			tlsConf := testdata.GetTLSConfig()
   940  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
   941  
   942  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   943  			Expect(err).ToNot(HaveOccurred())
   944  			defer ln.Close()
   945  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   946  			Expect(err).ToNot(HaveOccurred())
   947  			defer c.CloseWithError(0, "")
   948  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   949  		})
   950  
   951  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
   952  			tlsClientConf := testdata.GetTLSConfig()
   953  			tlsClientConf.NextProtos = []string{"foo", "bar"}
   954  			tlsConf := &tls.Config{
   955  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
   956  					return tlsClientConf, nil
   957  				},
   958  			}
   959  
   960  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   961  			Expect(err).ToNot(HaveOccurred())
   962  			defer ln.Close()
   963  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   964  			Expect(err).ToNot(HaveOccurred())
   965  			defer c.CloseWithError(0, "")
   966  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   967  			// check that the original config was not modified
   968  			Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
   969  		})
   970  	})
   971  
   972  	Context("Serve", func() {
   973  		origQuicListen := quicListen
   974  
   975  		AfterEach(func() {
   976  			quicListen = origQuicListen
   977  		})
   978  
   979  		It("serves a packet conn", func() {
   980  			ln := newMockAddrListener(":443")
   981  			conn := &net.UDPConn{}
   982  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
   983  				Expect(c).To(Equal(conn))
   984  				return ln, nil
   985  			}
   986  
   987  			s := &Server{
   988  				TLSConfig: &tls.Config{},
   989  			}
   990  
   991  			stopAccept := make(chan struct{})
   992  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
   993  				<-stopAccept
   994  				return nil, errors.New("closed")
   995  			})
   996  			ln.EXPECT().Addr() // generate alt-svc headers
   997  			done := make(chan struct{})
   998  			go func() {
   999  				defer GinkgoRecover()
  1000  				defer close(done)
  1001  				s.Serve(conn)
  1002  			}()
  1003  
  1004  			Consistently(done).ShouldNot(BeClosed())
  1005  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1006  			Expect(s.Close()).To(Succeed())
  1007  			Eventually(done).Should(BeClosed())
  1008  		})
  1009  
  1010  		It("serves two packet conns", func() {
  1011  			ln1 := newMockAddrListener(":443")
  1012  			ln2 := newMockAddrListener(":8443")
  1013  			lns := make(chan QUICEarlyListener, 2)
  1014  			lns <- ln1
  1015  			lns <- ln2
  1016  			conn1 := &net.UDPConn{}
  1017  			conn2 := &net.UDPConn{}
  1018  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1019  				return <-lns, nil
  1020  			}
  1021  
  1022  			s := &Server{
  1023  				TLSConfig: &tls.Config{},
  1024  			}
  1025  
  1026  			stopAccept1 := make(chan struct{})
  1027  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1028  				<-stopAccept1
  1029  				return nil, errors.New("closed")
  1030  			})
  1031  			ln1.EXPECT().Addr() // generate alt-svc headers
  1032  			stopAccept2 := make(chan struct{})
  1033  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1034  				<-stopAccept2
  1035  				return nil, errors.New("closed")
  1036  			})
  1037  			ln2.EXPECT().Addr()
  1038  
  1039  			done1 := make(chan struct{})
  1040  			go func() {
  1041  				defer GinkgoRecover()
  1042  				defer close(done1)
  1043  				s.Serve(conn1)
  1044  			}()
  1045  			done2 := make(chan struct{})
  1046  			go func() {
  1047  				defer GinkgoRecover()
  1048  				defer close(done2)
  1049  				s.Serve(conn2)
  1050  			}()
  1051  
  1052  			Consistently(done1).ShouldNot(BeClosed())
  1053  			Expect(done2).ToNot(BeClosed())
  1054  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1055  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1056  			Expect(s.Close()).To(Succeed())
  1057  			Eventually(done1).Should(BeClosed())
  1058  			Eventually(done2).Should(BeClosed())
  1059  		})
  1060  	})
  1061  
  1062  	Context("ServeListener", func() {
  1063  		origQuicListen := quicListen
  1064  
  1065  		AfterEach(func() {
  1066  			quicListen = origQuicListen
  1067  		})
  1068  
  1069  		It("serves a listener", func() {
  1070  			var called int32
  1071  			ln := newMockAddrListener(":443")
  1072  			quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1073  				atomic.StoreInt32(&called, 1)
  1074  				return ln, nil
  1075  			}
  1076  
  1077  			s := &Server{}
  1078  
  1079  			stopAccept := make(chan struct{})
  1080  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1081  				<-stopAccept
  1082  				return nil, errors.New("closed")
  1083  			})
  1084  			ln.EXPECT().Addr() // generate alt-svc headers
  1085  			done := make(chan struct{})
  1086  			go func() {
  1087  				defer GinkgoRecover()
  1088  				defer close(done)
  1089  				s.ServeListener(ln)
  1090  			}()
  1091  
  1092  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1093  			Consistently(done).ShouldNot(BeClosed())
  1094  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1095  			Expect(s.Close()).To(Succeed())
  1096  			Eventually(done).Should(BeClosed())
  1097  		})
  1098  
  1099  		It("serves two listeners", func() {
  1100  			var called int32
  1101  			ln1 := newMockAddrListener(":443")
  1102  			ln2 := newMockAddrListener(":8443")
  1103  			lns := make(chan QUICEarlyListener, 2)
  1104  			lns <- ln1
  1105  			lns <- ln2
  1106  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1107  				atomic.StoreInt32(&called, 1)
  1108  				return <-lns, nil
  1109  			}
  1110  
  1111  			s := &Server{}
  1112  
  1113  			stopAccept1 := make(chan struct{})
  1114  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1115  				<-stopAccept1
  1116  				return nil, errors.New("closed")
  1117  			})
  1118  			ln1.EXPECT().Addr() // generate alt-svc headers
  1119  			stopAccept2 := make(chan struct{})
  1120  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1121  				<-stopAccept2
  1122  				return nil, errors.New("closed")
  1123  			})
  1124  			ln2.EXPECT().Addr()
  1125  
  1126  			done1 := make(chan struct{})
  1127  			go func() {
  1128  				defer GinkgoRecover()
  1129  				defer close(done1)
  1130  				s.ServeListener(ln1)
  1131  			}()
  1132  			done2 := make(chan struct{})
  1133  			go func() {
  1134  				defer GinkgoRecover()
  1135  				defer close(done2)
  1136  				s.ServeListener(ln2)
  1137  			}()
  1138  
  1139  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1140  			Consistently(done1).ShouldNot(BeClosed())
  1141  			Expect(done2).ToNot(BeClosed())
  1142  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1143  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1144  			Expect(s.Close()).To(Succeed())
  1145  			Eventually(done1).Should(BeClosed())
  1146  			Eventually(done2).Should(BeClosed())
  1147  		})
  1148  	})
  1149  
  1150  	Context("ServeQUICConn", func() {
  1151  		It("serves a QUIC connection", func() {
  1152  			mux := http.NewServeMux()
  1153  			mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
  1154  				w.Write([]byte("foobar"))
  1155  			})
  1156  			s.Handler = mux
  1157  			tlsConf := testdata.GetTLSConfig()
  1158  			tlsConf.NextProtos = []string{NextProtoH3}
  1159  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
  1160  			controlStr := mockquic.NewMockStream(mockCtrl)
  1161  			controlStr.EXPECT().Write(gomock.Any())
  1162  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
  1163  			testDone := make(chan struct{})
  1164  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
  1165  				<-testDone
  1166  				return nil, errors.New("test done")
  1167  			}).MaxTimes(1)
  1168  			conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})
  1169  			s.ServeQUICConn(conn)
  1170  			close(testDone)
  1171  		})
  1172  	})
  1173  
  1174  	Context("ListenAndServe", func() {
  1175  		BeforeEach(func() {
  1176  			s.Addr = "localhost:0"
  1177  		})
  1178  
  1179  		AfterEach(func() {
  1180  			Expect(s.Close()).To(Succeed())
  1181  		})
  1182  
  1183  		It("uses the quic.Config to start the QUIC server", func() {
  1184  			conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
  1185  			var receivedConf *quic.Config
  1186  			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1187  				receivedConf = config
  1188  				return nil, errors.New("listen err")
  1189  			}
  1190  			s.QUICConfig = conf
  1191  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1192  			Expect(receivedConf).To(Equal(conf))
  1193  		})
  1194  	})
  1195  
  1196  	It("closes gracefully", func() {
  1197  		Expect(s.CloseGracefully(0)).To(Succeed())
  1198  	})
  1199  
  1200  	It("errors when listening fails", func() {
  1201  		testErr := errors.New("listen error")
  1202  		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1203  			return nil, testErr
  1204  		}
  1205  		fullpem, privkey := testdata.GetCertificatePaths()
  1206  		Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
  1207  	})
  1208  
  1209  	It("supports H3_DATAGRAM", func() {
  1210  		s.EnableDatagrams = true
  1211  		var receivedConf *quic.Config
  1212  		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1213  			receivedConf = config
  1214  			return nil, errors.New("listen err")
  1215  		}
  1216  		Expect(s.ListenAndServe()).To(HaveOccurred())
  1217  		Expect(receivedConf.EnableDatagrams).To(BeTrue())
  1218  	})
  1219  })