github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/server_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"io"
     9  	"log/slog"
    10  	"net"
    11  	"net/http"
    12  	"runtime"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/apernet/quic-go"
    17  	mockquic "github.com/apernet/quic-go/internal/mocks/quic"
    18  	"github.com/apernet/quic-go/internal/protocol"
    19  	"github.com/apernet/quic-go/internal/testdata"
    20  	"github.com/apernet/quic-go/quicvarint"
    21  
    22  	"github.com/quic-go/qpack"
    23  	"go.uber.org/mock/gomock"
    24  
    25  	. "github.com/onsi/ginkgo/v2"
    26  	. "github.com/onsi/gomega"
    27  	gmtypes "github.com/onsi/gomega/types"
    28  )
    29  
    30  type mockAddr struct{ addr string }
    31  
    32  func (ma *mockAddr) Network() string { return "udp" }
    33  func (ma *mockAddr) String() string  { return ma.addr }
    34  
    35  type mockAddrListener struct {
    36  	*MockQUICEarlyListener
    37  	addr *mockAddr
    38  }
    39  
    40  func (m *mockAddrListener) Addr() net.Addr {
    41  	_ = m.MockQUICEarlyListener.Addr()
    42  	return m.addr
    43  }
    44  
    45  func newMockAddrListener(addr string) *mockAddrListener {
    46  	return &mockAddrListener{
    47  		MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl),
    48  		addr:                  &mockAddr{addr: addr},
    49  	}
    50  }
    51  
    52  type noPortListener struct {
    53  	*mockAddrListener
    54  }
    55  
    56  func (m *noPortListener) Addr() net.Addr {
    57  	_ = m.mockAddrListener.Addr()
    58  	return &net.UnixAddr{
    59  		Net:  "unix",
    60  		Name: "/tmp/quic.sock",
    61  	}
    62  }
    63  
    64  var _ = Describe("Server", func() {
    65  	var (
    66  		s                  *Server
    67  		origQuicListenAddr = quicListenAddr
    68  	)
    69  	type testConnContextKey string
    70  
    71  	BeforeEach(func() {
    72  		s = &Server{
    73  			TLSConfig: testdata.GetTLSConfig(),
    74  			ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
    75  				return context.WithValue(ctx, testConnContextKey("test"), c)
    76  			},
    77  		}
    78  		origQuicListenAddr = quicListenAddr
    79  	})
    80  
    81  	AfterEach(func() {
    82  		quicListenAddr = origQuicListenAddr
    83  	})
    84  
    85  	Context("handling requests", func() {
    86  		var (
    87  			qpackDecoder       *qpack.Decoder
    88  			str                *mockquic.MockStream
    89  			conn               *connection
    90  			exampleGetRequest  *http.Request
    91  			examplePostRequest *http.Request
    92  		)
    93  		reqContext := context.Background()
    94  
    95  		decodeHeader := func(str io.Reader) map[string][]string {
    96  			fields := make(map[string][]string)
    97  			decoder := qpack.NewDecoder(nil)
    98  
    99  			frame, err := parseNextFrame(str, nil)
   100  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   101  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   102  			headersFrame := frame.(*headersFrame)
   103  			data := make([]byte, headersFrame.Length)
   104  			_, err = io.ReadFull(str, data)
   105  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   106  			hfs, err := decoder.DecodeFull(data)
   107  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   108  			for _, p := range hfs {
   109  				fields[p.Name] = append(fields[p.Name], p.Value)
   110  			}
   111  			return fields
   112  		}
   113  
   114  		encodeRequest := func(req *http.Request) []byte {
   115  			buf := &bytes.Buffer{}
   116  			str := mockquic.NewMockStream(mockCtrl)
   117  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   118  			rw := newRequestWriter()
   119  			Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
   120  			return buf.Bytes()
   121  		}
   122  
   123  		setRequest := func(data []byte) {
   124  			buf := bytes.NewBuffer(data)
   125  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   126  				if buf.Len() == 0 {
   127  					return 0, io.EOF
   128  				}
   129  				return buf.Read(p)
   130  			}).AnyTimes()
   131  		}
   132  
   133  		BeforeEach(func() {
   134  			var err error
   135  			exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
   136  			Expect(err).ToNot(HaveOccurred())
   137  			examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
   138  			Expect(err).ToNot(HaveOccurred())
   139  
   140  			qpackDecoder = qpack.NewDecoder(nil)
   141  			str = mockquic.NewMockStream(mockCtrl)
   142  			str.EXPECT().StreamID().AnyTimes()
   143  			qconn := mockquic.NewMockEarlyConnection(mockCtrl)
   144  			addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   145  			qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   146  			qconn.EXPECT().LocalAddr().AnyTimes()
   147  			qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   148  			qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
   149  			conn = newConnection(qconn, false, protocol.PerspectiveServer, nil)
   150  		})
   151  
   152  		It("calls the HTTP handler function", func() {
   153  			requestChan := make(chan *http.Request, 1)
   154  			s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
   155  				requestChan <- r
   156  			})
   157  
   158  			setRequest(encodeRequest(exampleGetRequest))
   159  			str.EXPECT().Context().Return(reqContext)
   160  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   161  				return len(p), nil
   162  			}).AnyTimes()
   163  			str.EXPECT().CancelRead(gomock.Any())
   164  			str.EXPECT().Close()
   165  
   166  			s.handleRequest(conn, str, nil, qpackDecoder)
   167  			var req *http.Request
   168  			Eventually(requestChan).Should(Receive(&req))
   169  			Expect(req.Host).To(Equal("www.example.com"))
   170  			Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
   171  			Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
   172  			Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil))
   173  		})
   174  
   175  		It("returns 200 with an empty handler", func() {
   176  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   177  
   178  			responseBuf := &bytes.Buffer{}
   179  			setRequest(encodeRequest(exampleGetRequest))
   180  			str.EXPECT().Context().Return(reqContext)
   181  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   182  			str.EXPECT().CancelRead(gomock.Any())
   183  			str.EXPECT().Close()
   184  
   185  			s.handleRequest(conn, str, nil, qpackDecoder)
   186  			hfs := decodeHeader(responseBuf)
   187  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   188  		})
   189  
   190  		It("sets Content-Length when the handler doesn't flush to the client", func() {
   191  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   192  				w.Write([]byte("foobar"))
   193  			})
   194  
   195  			responseBuf := &bytes.Buffer{}
   196  			setRequest(encodeRequest(exampleGetRequest))
   197  			str.EXPECT().Context().Return(reqContext)
   198  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   199  			str.EXPECT().CancelRead(gomock.Any())
   200  			str.EXPECT().Close()
   201  
   202  			s.handleRequest(conn, str, nil, qpackDecoder)
   203  			hfs := decodeHeader(responseBuf)
   204  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   205  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
   206  			// status, content-length, date, content-type
   207  			Expect(hfs).To(HaveLen(4))
   208  		})
   209  
   210  		It("sets Content-Type when WriteHeader is called but response is not flushed", func() {
   211  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   212  				w.WriteHeader(http.StatusNotFound)
   213  				w.Write([]byte("<html></html>"))
   214  			})
   215  
   216  			responseBuf := &bytes.Buffer{}
   217  			setRequest(encodeRequest(exampleGetRequest))
   218  			str.EXPECT().Context().Return(reqContext)
   219  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   220  			str.EXPECT().CancelRead(gomock.Any())
   221  			str.EXPECT().Close()
   222  
   223  			s.handleRequest(conn, str, nil, qpackDecoder)
   224  			hfs := decodeHeader(responseBuf)
   225  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"404"}))
   226  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   227  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   228  		})
   229  
   230  		It("not sets Content-Length when the handler flushes to the client", func() {
   231  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   232  				w.Write([]byte("foobar"))
   233  				// force flush
   234  				w.(http.Flusher).Flush()
   235  			})
   236  
   237  			responseBuf := &bytes.Buffer{}
   238  			setRequest(encodeRequest(exampleGetRequest))
   239  			str.EXPECT().Context().Return(reqContext)
   240  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   241  			str.EXPECT().CancelRead(gomock.Any())
   242  			str.EXPECT().Close()
   243  
   244  			s.handleRequest(conn, str, nil, qpackDecoder)
   245  			hfs := decodeHeader(responseBuf)
   246  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   247  			// status, date, content-type
   248  			Expect(hfs).To(HaveLen(3))
   249  		})
   250  
   251  		It("ignores calls to Write for responses to HEAD requests", func() {
   252  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   253  				w.Write([]byte("foobar"))
   254  			})
   255  
   256  			headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
   257  			Expect(err).ToNot(HaveOccurred())
   258  			responseBuf := &bytes.Buffer{}
   259  			setRequest(encodeRequest(headRequest))
   260  			str.EXPECT().Context().Return(reqContext)
   261  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   262  			str.EXPECT().CancelRead(gomock.Any())
   263  			str.EXPECT().Close()
   264  
   265  			s.handleRequest(conn, str, nil, qpackDecoder)
   266  			hfs := decodeHeader(responseBuf)
   267  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   268  			Expect(responseBuf.Bytes()).To(BeEmpty())
   269  		})
   270  
   271  		It("response to HEAD request should also do content sniffing", func() {
   272  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   273  				w.Write([]byte("<html></html>"))
   274  			})
   275  
   276  			headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
   277  			Expect(err).ToNot(HaveOccurred())
   278  			responseBuf := &bytes.Buffer{}
   279  			setRequest(encodeRequest(headRequest))
   280  			str.EXPECT().Context().Return(reqContext)
   281  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   282  			str.EXPECT().CancelRead(gomock.Any())
   283  			str.EXPECT().Close()
   284  
   285  			s.handleRequest(conn, str, nil, qpackDecoder)
   286  			hfs := decodeHeader(responseBuf)
   287  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   288  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   289  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   290  		})
   291  
   292  		It("handles an aborting handler", func() {
   293  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   294  				panic(http.ErrAbortHandler)
   295  			})
   296  
   297  			responseBuf := &bytes.Buffer{}
   298  			setRequest(encodeRequest(exampleGetRequest))
   299  			str.EXPECT().Context().Return(reqContext)
   300  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   301  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
   302  			str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
   303  
   304  			s.handleRequest(conn, str, nil, qpackDecoder)
   305  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   306  		})
   307  
   308  		It("handles a panicking handler", func() {
   309  			var logBuf bytes.Buffer
   310  			s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil))
   311  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   312  				panic("foobar")
   313  			})
   314  
   315  			responseBuf := &bytes.Buffer{}
   316  			setRequest(encodeRequest(exampleGetRequest))
   317  			str.EXPECT().Context().Return(reqContext)
   318  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   319  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
   320  			str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
   321  
   322  			s.handleRequest(conn, str, nil, qpackDecoder)
   323  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   324  			Expect(logBuf.String()).To(ContainSubstring("http: panic serving"))
   325  			Expect(logBuf.String()).To(ContainSubstring("foobar"))
   326  		})
   327  
   328  		Context("hijacking bidirectional streams", func() {
   329  			var conn *mockquic.MockEarlyConnection
   330  			testDone := make(chan struct{})
   331  
   332  			BeforeEach(func() {
   333  				testDone = make(chan struct{})
   334  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   335  				controlStr := mockquic.NewMockStream(mockCtrl)
   336  				controlStr.EXPECT().Write(gomock.Any())
   337  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   338  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   339  				conn.EXPECT().LocalAddr().AnyTimes()
   340  			})
   341  
   342  			AfterEach(func() { testDone <- struct{}{} })
   343  
   344  			It("hijacks a bidirectional stream of unknown frame type", func() {
   345  				id := quic.ConnectionTracingID(1337)
   346  				frameTypeChan := make(chan FrameType, 1)
   347  				s.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   348  					defer GinkgoRecover()
   349  					Expect(e).ToNot(HaveOccurred())
   350  					Expect(connTracingID).To(Equal(id))
   351  					frameTypeChan <- ft
   352  					return true, nil
   353  				}
   354  
   355  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   356  				unknownStr := mockquic.NewMockStream(mockCtrl)
   357  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   358  				unknownStr.EXPECT().StreamID().AnyTimes()
   359  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   360  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   361  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   362  					<-testDone
   363  					return nil, errors.New("test done")
   364  				})
   365  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
   366  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   367  				s.handleConn(conn)
   368  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   369  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   370  			})
   371  
   372  			It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
   373  				frameTypeChan := make(chan FrameType, 1)
   374  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   375  					Expect(e).ToNot(HaveOccurred())
   376  					frameTypeChan <- ft
   377  					return false, nil
   378  				}
   379  
   380  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   381  				unknownStr := mockquic.NewMockStream(mockCtrl)
   382  				unknownStr.EXPECT().StreamID().AnyTimes()
   383  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   384  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   385  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   386  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   387  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   388  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   389  					<-testDone
   390  					return nil, errors.New("test done")
   391  				})
   392  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   393  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   394  				s.handleConn(conn)
   395  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   396  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   397  			})
   398  
   399  			It("cancels writing when hijacker returned error", func() {
   400  				frameTypeChan := make(chan FrameType, 1)
   401  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
   402  					Expect(e).ToNot(HaveOccurred())
   403  					frameTypeChan <- ft
   404  					return false, errors.New("error in hijacker")
   405  				}
   406  
   407  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   408  				unknownStr := mockquic.NewMockStream(mockCtrl)
   409  				unknownStr.EXPECT().StreamID().AnyTimes()
   410  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   411  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   412  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   413  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   414  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   415  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   416  					<-testDone
   417  					return nil, errors.New("test done")
   418  				})
   419  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   420  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   421  				s.handleConn(conn)
   422  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   423  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   424  			})
   425  
   426  			It("handles errors that occur when reading the stream type", func() {
   427  				const strID = protocol.StreamID(1234 * 4)
   428  				testErr := errors.New("test error")
   429  				done := make(chan struct{})
   430  				unknownStr := mockquic.NewMockStream(mockCtrl)
   431  				s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) {
   432  					defer close(done)
   433  					Expect(ft).To(BeZero())
   434  					Expect(str.StreamID()).To(Equal(strID))
   435  					Expect(err).To(MatchError(testErr))
   436  					return true, nil
   437  				}
   438  				unknownStr.EXPECT().StreamID().Return(strID).AnyTimes()
   439  				unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   440  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   441  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   442  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   443  					<-testDone
   444  					return nil, errors.New("test done")
   445  				})
   446  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   447  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   448  				s.handleConn(conn)
   449  				Eventually(done).Should(BeClosed())
   450  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   451  			})
   452  		})
   453  
   454  		Context("hijacking unidirectional streams", func() {
   455  			var conn *mockquic.MockEarlyConnection
   456  			testDone := make(chan struct{})
   457  
   458  			BeforeEach(func() {
   459  				testDone = make(chan struct{})
   460  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   461  				controlStr := mockquic.NewMockStream(mockCtrl)
   462  				controlStr.EXPECT().Write(gomock.Any())
   463  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   464  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   465  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   466  				conn.EXPECT().LocalAddr().AnyTimes()
   467  			})
   468  
   469  			AfterEach(func() { testDone <- struct{}{} })
   470  
   471  			It("hijacks an unidirectional stream of unknown stream type", func() {
   472  				id := quic.ConnectionTracingID(42)
   473  				streamTypeChan := make(chan StreamType, 1)
   474  				s.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
   475  					Expect(err).ToNot(HaveOccurred())
   476  					Expect(connTracingID).To(Equal(id))
   477  					streamTypeChan <- st
   478  					return true
   479  				}
   480  
   481  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   482  				unknownStr := mockquic.NewMockStream(mockCtrl)
   483  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   484  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   485  					return unknownStr, nil
   486  				})
   487  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   488  					<-testDone
   489  					return nil, errors.New("test done")
   490  				})
   491  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
   492  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   493  				s.handleConn(conn)
   494  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   495  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   496  			})
   497  
   498  			It("handles errors that occur when reading the stream type", func() {
   499  				testErr := errors.New("test error")
   500  				done := make(chan struct{})
   501  				unknownStr := mockquic.NewMockStream(mockCtrl)
   502  				s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
   503  					defer close(done)
   504  					Expect(st).To(BeZero())
   505  					Expect(str).To(Equal(unknownStr))
   506  					Expect(err).To(MatchError(testErr))
   507  					return true
   508  				}
   509  
   510  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
   511  				conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   512  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   513  					<-testDone
   514  					return nil, errors.New("test done")
   515  				})
   516  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   517  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   518  				s.handleConn(conn)
   519  				Eventually(done).Should(BeClosed())
   520  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   521  			})
   522  
   523  			It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   524  				streamTypeChan := make(chan StreamType, 1)
   525  				s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
   526  					Expect(err).ToNot(HaveOccurred())
   527  					streamTypeChan <- st
   528  					return false
   529  				}
   530  
   531  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   532  				unknownStr := mockquic.NewMockStream(mockCtrl)
   533  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   534  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   535  
   536  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   537  					return unknownStr, nil
   538  				})
   539  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   540  					<-testDone
   541  					return nil, errors.New("test done")
   542  				})
   543  				ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
   544  				conn.EXPECT().Context().Return(ctx).AnyTimes()
   545  				s.handleConn(conn)
   546  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   547  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   548  			})
   549  		})
   550  
   551  		Context("stream- and connection-level errors", func() {
   552  			var conn *mockquic.MockEarlyConnection
   553  			testDone := make(chan struct{})
   554  
   555  			BeforeEach(func() {
   556  				testDone = make(chan struct{})
   557  				addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   558  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   559  				controlStr := mockquic.NewMockStream(mockCtrl)
   560  				controlStr.EXPECT().Write(gomock.Any())
   561  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   562  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   563  					<-testDone
   564  					return nil, errors.New("test done")
   565  				})
   566  				conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
   567  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   568  				conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   569  				conn.EXPECT().LocalAddr().AnyTimes()
   570  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   571  				conn.EXPECT().Context().Return(context.Background()).AnyTimes()
   572  			})
   573  
   574  			AfterEach(func() { testDone <- struct{}{} })
   575  
   576  			It("cancels reading when client sends a body in GET request", func() {
   577  				var handlerCalled bool
   578  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   579  					handlerCalled = true
   580  				})
   581  
   582  				requestData := encodeRequest(exampleGetRequest)
   583  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   584  				b = append(b, []byte("foobar")...)
   585  				responseBuf := &bytes.Buffer{}
   586  				setRequest(append(requestData, b...))
   587  				done := make(chan struct{})
   588  				str.EXPECT().Context().Return(reqContext)
   589  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   590  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   591  				str.EXPECT().Close().Do(func() error { close(done); return nil })
   592  
   593  				s.handleConn(conn)
   594  				Eventually(done).Should(BeClosed())
   595  				hfs := decodeHeader(responseBuf)
   596  				Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   597  				Expect(handlerCalled).To(BeTrue())
   598  			})
   599  
   600  			It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
   601  				handlerCalled := make(chan struct{})
   602  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   603  					defer close(handlerCalled)
   604  					w.(HTTPStreamer).HTTPStream()
   605  					str.Write([]byte("foobar"))
   606  				})
   607  
   608  				requestData := encodeRequest(exampleGetRequest)
   609  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   610  				b = append(b, []byte("foobar")...)
   611  				setRequest(append(requestData, b...))
   612  				str.EXPECT().Context().Return(reqContext)
   613  				var buf bytes.Buffer
   614  				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   615  
   616  				s.handleConn(conn)
   617  				Eventually(handlerCalled).Should(BeClosed())
   618  
   619  				// The buffer is expected to contain:
   620  				// 1. The response header (in a HEADERS frame)
   621  				// 2. the "foobar" (unframed)
   622  				frame, err := parseNextFrame(&buf, nil)
   623  				Expect(err).ToNot(HaveOccurred())
   624  				Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
   625  				df := frame.(*headersFrame)
   626  				data := make([]byte, df.Length)
   627  				_, err = io.ReadFull(&buf, data)
   628  				Expect(err).ToNot(HaveOccurred())
   629  				hdrs, err := qpackDecoder.DecodeFull(data)
   630  				Expect(err).ToNot(HaveOccurred())
   631  				Expect(hdrs).To(ContainElement(qpack.HeaderField{Name: ":status", Value: "200"}))
   632  				Expect(buf.Bytes()).To(Equal([]byte("foobar")))
   633  			})
   634  
   635  			It("errors when the client sends a too large header frame", func() {
   636  				s.MaxHeaderBytes = 20
   637  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   638  					Fail("Handler should not be called.")
   639  				})
   640  
   641  				requestData := encodeRequest(exampleGetRequest)
   642  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   643  				b = append(b, []byte("foobar")...)
   644  				responseBuf := &bytes.Buffer{}
   645  				setRequest(append(requestData, b...))
   646  				done := make(chan struct{})
   647  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   648  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
   649  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   650  
   651  				s.handleConn(conn)
   652  				Eventually(done).Should(BeClosed())
   653  			})
   654  
   655  			It("handles a request for which the client immediately resets the stream", func() {
   656  				handlerCalled := make(chan struct{})
   657  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   658  					close(handlerCalled)
   659  				})
   660  
   661  				testErr := errors.New("stream reset")
   662  				done := make(chan struct{})
   663  				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
   664  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   665  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
   666  
   667  				s.handleConn(conn)
   668  				Consistently(handlerCalled).ShouldNot(BeClosed())
   669  			})
   670  
   671  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   672  				handlerCalled := make(chan struct{})
   673  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   674  					close(handlerCalled)
   675  				})
   676  
   677  				b := (&dataFrame{}).Append(nil)
   678  				setRequest(b)
   679  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   680  					return len(p), nil
   681  				}).AnyTimes()
   682  
   683  				done := make(chan struct{})
   684  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   685  					close(done)
   686  					return nil
   687  				})
   688  				s.handleConn(conn)
   689  				Eventually(done).Should(BeClosed())
   690  			})
   691  
   692  			It("rejects a request that has too large request headers", func() {
   693  				handlerCalled := make(chan struct{})
   694  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   695  					close(handlerCalled)
   696  				})
   697  
   698  				// use 2*DefaultMaxHeaderBytes here. qpack will compress the request,
   699  				// but the request will still end up larger than DefaultMaxHeaderBytes.
   700  				url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
   701  				req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
   702  				Expect(err).ToNot(HaveOccurred())
   703  				setRequest(encodeRequest(req))
   704  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   705  					return len(p), nil
   706  				}).AnyTimes()
   707  				done := make(chan struct{})
   708  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
   709  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   710  
   711  				s.handleConn(conn)
   712  				Eventually(done).Should(BeClosed())
   713  			})
   714  		})
   715  
   716  		It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
   717  			handlerCalled := make(chan struct{})
   718  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   719  				r.Body = struct {
   720  					io.Reader
   721  					io.Closer
   722  				}{}
   723  				close(handlerCalled)
   724  			})
   725  
   726  			setRequest(encodeRequest(examplePostRequest))
   727  			str.EXPECT().Context().Return(reqContext)
   728  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   729  				return len(p), nil
   730  			}).AnyTimes()
   731  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   732  			str.EXPECT().Close()
   733  
   734  			s.handleRequest(conn, str, nil, qpackDecoder)
   735  			Eventually(handlerCalled).Should(BeClosed())
   736  		})
   737  
   738  		It("cancels the request context when the stream is closed", func() {
   739  			handlerCalled := make(chan struct{})
   740  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   741  				defer GinkgoRecover()
   742  				Expect(r.Context().Done()).To(BeClosed())
   743  				Expect(r.Context().Err()).To(MatchError(context.Canceled))
   744  				close(handlerCalled)
   745  			})
   746  			setRequest(encodeRequest(examplePostRequest))
   747  
   748  			reqContext, cancel := context.WithCancel(context.Background())
   749  			cancel()
   750  			str.EXPECT().Context().Return(reqContext)
   751  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   752  				return len(p), nil
   753  			}).AnyTimes()
   754  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   755  			str.EXPECT().Close()
   756  
   757  			s.handleRequest(conn, str, nil, qpackDecoder)
   758  			Eventually(handlerCalled).Should(BeClosed())
   759  		})
   760  	})
   761  
   762  	Context("setting http headers", func() {
   763  		BeforeEach(func() {
   764  			s.QUICConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}}
   765  		})
   766  
   767  		var ln1 QUICEarlyListener
   768  		var ln2 QUICEarlyListener
   769  		expected := http.Header{
   770  			"Alt-Svc": {`h3=":443"; ma=2592000`},
   771  		}
   772  
   773  		addListener := func(addr string, ln *QUICEarlyListener) {
   774  			mln := newMockAddrListener(addr)
   775  			mln.EXPECT().Addr()
   776  			*ln = mln
   777  			s.addListener(ln)
   778  		}
   779  
   780  		removeListener := func(ln *QUICEarlyListener) {
   781  			s.removeListener(ln)
   782  		}
   783  
   784  		checkSetHeaders := func(expected gmtypes.GomegaMatcher) {
   785  			hdr := http.Header{}
   786  			Expect(s.SetQUICHeaders(hdr)).To(Succeed())
   787  			Expect(hdr).To(expected)
   788  		}
   789  
   790  		checkSetHeaderError := func() {
   791  			hdr := http.Header{}
   792  			Expect(s.SetQUICHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
   793  		}
   794  
   795  		It("sets proper headers with numeric port", func() {
   796  			addListener(":443", &ln1)
   797  			checkSetHeaders(Equal(expected))
   798  			removeListener(&ln1)
   799  			checkSetHeaderError()
   800  		})
   801  
   802  		It("sets proper headers with full addr", func() {
   803  			addListener("127.0.0.1:443", &ln1)
   804  			checkSetHeaders(Equal(expected))
   805  			removeListener(&ln1)
   806  			checkSetHeaderError()
   807  		})
   808  
   809  		It("sets proper headers with string port", func() {
   810  			addListener(":https", &ln1)
   811  			checkSetHeaders(Equal(expected))
   812  			removeListener(&ln1)
   813  			checkSetHeaderError()
   814  		})
   815  
   816  		It("works multiple times", func() {
   817  			addListener(":https", &ln1)
   818  			checkSetHeaders(Equal(expected))
   819  			checkSetHeaders(Equal(expected))
   820  			removeListener(&ln1)
   821  			checkSetHeaderError()
   822  		})
   823  
   824  		It("works if the quic.Config sets QUIC versions", func() {
   825  			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version2}
   826  			addListener(":443", &ln1)
   827  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   828  			removeListener(&ln1)
   829  			checkSetHeaderError()
   830  		})
   831  
   832  		It("uses s.Port if set to a non-zero value", func() {
   833  			s.Port = 8443
   834  			addListener(":443", &ln1)
   835  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
   836  			removeListener(&ln1)
   837  			checkSetHeaderError()
   838  		})
   839  
   840  		It("uses s.Addr if listeners don't have ports available", func() {
   841  			s.Addr = ":443"
   842  			var logBuf bytes.Buffer
   843  			s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil))
   844  			mln := &noPortListener{newMockAddrListener("")}
   845  			mln.EXPECT().Addr()
   846  			ln1 = mln
   847  			s.addListener(&ln1)
   848  			checkSetHeaders(Equal(expected))
   849  			s.removeListener(&ln1)
   850  			checkSetHeaderError()
   851  			Expect(logBuf.String()).To(ContainSubstring("Unable to extract port from listener, will not be announced using SetQUICHeaders"))
   852  		})
   853  
   854  		It("properly announces multiple listeners", func() {
   855  			addListener(":443", &ln1)
   856  			addListener(":8443", &ln2)
   857  			checkSetHeaders(Or(
   858  				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
   859  				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
   860  			))
   861  			removeListener(&ln1)
   862  			removeListener(&ln2)
   863  			checkSetHeaderError()
   864  		})
   865  
   866  		It("doesn't duplicate Alt-Svc values", func() {
   867  			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version1}
   868  			addListener(":443", &ln1)
   869  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   870  			removeListener(&ln1)
   871  			checkSetHeaderError()
   872  		})
   873  	})
   874  
   875  	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
   876  		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
   877  	})
   878  
   879  	It("should nop-Close() when s.server is nil", func() {
   880  		Expect((&Server{}).Close()).To(Succeed())
   881  	})
   882  
   883  	It("errors when ListenAndServeTLS is called after Close", func() {
   884  		serv := &Server{}
   885  		Expect(serv.Close()).To(Succeed())
   886  		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
   887  	})
   888  
   889  	It("handles concurrent Serve and Close", func() {
   890  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   891  		Expect(err).ToNot(HaveOccurred())
   892  		c, err := net.ListenUDP("udp", addr)
   893  		Expect(err).ToNot(HaveOccurred())
   894  		done := make(chan struct{})
   895  		go func() {
   896  			defer GinkgoRecover()
   897  			defer close(done)
   898  			s.Serve(c)
   899  		}()
   900  		runtime.Gosched()
   901  		s.Close()
   902  		Eventually(done).Should(BeClosed())
   903  	})
   904  
   905  	Context("ConfigureTLSConfig", func() {
   906  		It("advertises v1 by default", func() {
   907  			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
   908  			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}})
   909  			Expect(err).ToNot(HaveOccurred())
   910  			defer ln.Close()
   911  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   912  			Expect(err).ToNot(HaveOccurred())
   913  			defer c.CloseWithError(0, "")
   914  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   915  		})
   916  
   917  		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
   918  			var receivedConf *tls.Config
   919  			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
   920  				receivedConf = tlsConf
   921  				return nil, errors.New("listen err")
   922  			}
   923  			Expect(s.ListenAndServe()).To(HaveOccurred())
   924  			Expect(receivedConf).ToNot(BeNil())
   925  		})
   926  
   927  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
   928  			tlsConf := &tls.Config{
   929  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
   930  					c := testdata.GetTLSConfig()
   931  					c.NextProtos = []string{"foo", "bar"}
   932  					return c, nil
   933  				},
   934  			}
   935  
   936  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   937  			Expect(err).ToNot(HaveOccurred())
   938  			defer ln.Close()
   939  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   940  			Expect(err).ToNot(HaveOccurred())
   941  			defer c.CloseWithError(0, "")
   942  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   943  		})
   944  
   945  		It("works if GetConfigForClient returns a nil tls.Config", func() {
   946  			tlsConf := testdata.GetTLSConfig()
   947  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
   948  
   949  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   950  			Expect(err).ToNot(HaveOccurred())
   951  			defer ln.Close()
   952  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   953  			Expect(err).ToNot(HaveOccurred())
   954  			defer c.CloseWithError(0, "")
   955  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   956  		})
   957  
   958  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
   959  			tlsClientConf := testdata.GetTLSConfig()
   960  			tlsClientConf.NextProtos = []string{"foo", "bar"}
   961  			tlsConf := &tls.Config{
   962  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
   963  					return tlsClientConf, nil
   964  				},
   965  			}
   966  
   967  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
   968  			Expect(err).ToNot(HaveOccurred())
   969  			defer ln.Close()
   970  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   971  			Expect(err).ToNot(HaveOccurred())
   972  			defer c.CloseWithError(0, "")
   973  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   974  			// check that the original config was not modified
   975  			Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
   976  		})
   977  	})
   978  
   979  	Context("Serve", func() {
   980  		origQuicListen := quicListen
   981  
   982  		AfterEach(func() {
   983  			quicListen = origQuicListen
   984  		})
   985  
   986  		It("serves a packet conn", func() {
   987  			ln := newMockAddrListener(":443")
   988  			conn := &net.UDPConn{}
   989  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
   990  				Expect(c).To(Equal(conn))
   991  				return ln, nil
   992  			}
   993  
   994  			s := &Server{
   995  				TLSConfig: &tls.Config{},
   996  			}
   997  
   998  			stopAccept := make(chan struct{})
   999  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1000  				<-stopAccept
  1001  				return nil, errors.New("closed")
  1002  			})
  1003  			ln.EXPECT().Addr() // generate alt-svc headers
  1004  			done := make(chan struct{})
  1005  			go func() {
  1006  				defer GinkgoRecover()
  1007  				defer close(done)
  1008  				s.Serve(conn)
  1009  			}()
  1010  
  1011  			Consistently(done).ShouldNot(BeClosed())
  1012  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1013  			Expect(s.Close()).To(Succeed())
  1014  			Eventually(done).Should(BeClosed())
  1015  		})
  1016  
  1017  		It("serves two packet conns", func() {
  1018  			ln1 := newMockAddrListener(":443")
  1019  			ln2 := newMockAddrListener(":8443")
  1020  			lns := make(chan QUICEarlyListener, 2)
  1021  			lns <- ln1
  1022  			lns <- ln2
  1023  			conn1 := &net.UDPConn{}
  1024  			conn2 := &net.UDPConn{}
  1025  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1026  				return <-lns, nil
  1027  			}
  1028  
  1029  			s := &Server{
  1030  				TLSConfig: &tls.Config{},
  1031  			}
  1032  
  1033  			stopAccept1 := make(chan struct{})
  1034  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1035  				<-stopAccept1
  1036  				return nil, errors.New("closed")
  1037  			})
  1038  			ln1.EXPECT().Addr() // generate alt-svc headers
  1039  			stopAccept2 := make(chan struct{})
  1040  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1041  				<-stopAccept2
  1042  				return nil, errors.New("closed")
  1043  			})
  1044  			ln2.EXPECT().Addr()
  1045  
  1046  			done1 := make(chan struct{})
  1047  			go func() {
  1048  				defer GinkgoRecover()
  1049  				defer close(done1)
  1050  				s.Serve(conn1)
  1051  			}()
  1052  			done2 := make(chan struct{})
  1053  			go func() {
  1054  				defer GinkgoRecover()
  1055  				defer close(done2)
  1056  				s.Serve(conn2)
  1057  			}()
  1058  
  1059  			Consistently(done1).ShouldNot(BeClosed())
  1060  			Expect(done2).ToNot(BeClosed())
  1061  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1062  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1063  			Expect(s.Close()).To(Succeed())
  1064  			Eventually(done1).Should(BeClosed())
  1065  			Eventually(done2).Should(BeClosed())
  1066  		})
  1067  	})
  1068  
  1069  	Context("ServeListener", func() {
  1070  		origQuicListen := quicListen
  1071  
  1072  		AfterEach(func() {
  1073  			quicListen = origQuicListen
  1074  		})
  1075  
  1076  		It("serves a listener", func() {
  1077  			var called int32
  1078  			ln := newMockAddrListener(":443")
  1079  			quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1080  				atomic.StoreInt32(&called, 1)
  1081  				return ln, nil
  1082  			}
  1083  
  1084  			s := &Server{}
  1085  
  1086  			stopAccept := make(chan struct{})
  1087  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1088  				<-stopAccept
  1089  				return nil, errors.New("closed")
  1090  			})
  1091  			ln.EXPECT().Addr() // generate alt-svc headers
  1092  			done := make(chan struct{})
  1093  			go func() {
  1094  				defer GinkgoRecover()
  1095  				defer close(done)
  1096  				s.ServeListener(ln)
  1097  			}()
  1098  
  1099  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1100  			Consistently(done).ShouldNot(BeClosed())
  1101  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1102  			Expect(s.Close()).To(Succeed())
  1103  			Eventually(done).Should(BeClosed())
  1104  		})
  1105  
  1106  		It("serves two listeners", func() {
  1107  			var called int32
  1108  			ln1 := newMockAddrListener(":443")
  1109  			ln2 := newMockAddrListener(":8443")
  1110  			lns := make(chan QUICEarlyListener, 2)
  1111  			lns <- ln1
  1112  			lns <- ln2
  1113  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1114  				atomic.StoreInt32(&called, 1)
  1115  				return <-lns, nil
  1116  			}
  1117  
  1118  			s := &Server{}
  1119  
  1120  			stopAccept1 := make(chan struct{})
  1121  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1122  				<-stopAccept1
  1123  				return nil, errors.New("closed")
  1124  			})
  1125  			ln1.EXPECT().Addr() // generate alt-svc headers
  1126  			stopAccept2 := make(chan struct{})
  1127  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1128  				<-stopAccept2
  1129  				return nil, errors.New("closed")
  1130  			})
  1131  			ln2.EXPECT().Addr()
  1132  
  1133  			done1 := make(chan struct{})
  1134  			go func() {
  1135  				defer GinkgoRecover()
  1136  				defer close(done1)
  1137  				s.ServeListener(ln1)
  1138  			}()
  1139  			done2 := make(chan struct{})
  1140  			go func() {
  1141  				defer GinkgoRecover()
  1142  				defer close(done2)
  1143  				s.ServeListener(ln2)
  1144  			}()
  1145  
  1146  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1147  			Consistently(done1).ShouldNot(BeClosed())
  1148  			Expect(done2).ToNot(BeClosed())
  1149  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1150  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1151  			Expect(s.Close()).To(Succeed())
  1152  			Eventually(done1).Should(BeClosed())
  1153  			Eventually(done2).Should(BeClosed())
  1154  		})
  1155  	})
  1156  
  1157  	Context("ServeQUICConn", func() {
  1158  		It("serves a QUIC connection", func() {
  1159  			mux := http.NewServeMux()
  1160  			mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
  1161  				w.Write([]byte("foobar"))
  1162  			})
  1163  			s.Handler = mux
  1164  			tlsConf := testdata.GetTLSConfig()
  1165  			tlsConf.NextProtos = []string{NextProtoH3}
  1166  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
  1167  			controlStr := mockquic.NewMockStream(mockCtrl)
  1168  			controlStr.EXPECT().Write(gomock.Any())
  1169  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
  1170  			testDone := make(chan struct{})
  1171  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
  1172  				<-testDone
  1173  				return nil, errors.New("test done")
  1174  			}).MaxTimes(1)
  1175  			conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})
  1176  			s.ServeQUICConn(conn)
  1177  			close(testDone)
  1178  		})
  1179  	})
  1180  
  1181  	Context("ListenAndServe", func() {
  1182  		BeforeEach(func() {
  1183  			s.Addr = "localhost:0"
  1184  		})
  1185  
  1186  		AfterEach(func() {
  1187  			Expect(s.Close()).To(Succeed())
  1188  		})
  1189  
  1190  		It("uses the quic.Config to start the QUIC server", func() {
  1191  			conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
  1192  			var receivedConf *quic.Config
  1193  			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1194  				receivedConf = config
  1195  				return nil, errors.New("listen err")
  1196  			}
  1197  			s.QUICConfig = conf
  1198  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1199  			Expect(receivedConf).To(Equal(conf))
  1200  		})
  1201  	})
  1202  
  1203  	It("closes gracefully", func() {
  1204  		Expect(s.CloseGracefully(0)).To(Succeed())
  1205  	})
  1206  
  1207  	It("errors when listening fails", func() {
  1208  		testErr := errors.New("listen error")
  1209  		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1210  			return nil, testErr
  1211  		}
  1212  		fullpem, privkey := testdata.GetCertificatePaths()
  1213  		Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
  1214  	})
  1215  
  1216  	It("supports H3_DATAGRAM", func() {
  1217  		s.EnableDatagrams = true
  1218  		var receivedConf *quic.Config
  1219  		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1220  			receivedConf = config
  1221  			return nil, errors.New("listen err")
  1222  		}
  1223  		Expect(s.ListenAndServe()).To(HaveOccurred())
  1224  		Expect(receivedConf.EnableDatagrams).To(BeTrue())
  1225  	})
  1226  })