github.com/MerlinKodo/quic-go@v0.39.2/http3/server_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"runtime"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/MerlinKodo/quic-go"
    17  	mockquic "github.com/MerlinKodo/quic-go/internal/mocks/quic"
    18  	"github.com/MerlinKodo/quic-go/internal/protocol"
    19  	"github.com/MerlinKodo/quic-go/internal/testdata"
    20  	"github.com/MerlinKodo/quic-go/internal/utils"
    21  	"github.com/MerlinKodo/quic-go/quicvarint"
    22  
    23  	"github.com/quic-go/qpack"
    24  	"go.uber.org/mock/gomock"
    25  
    26  	. "github.com/onsi/ginkgo/v2"
    27  	. "github.com/onsi/gomega"
    28  	gmtypes "github.com/onsi/gomega/types"
    29  )
    30  
    31  type mockAddr struct{ addr string }
    32  
    33  func (ma *mockAddr) Network() string { return "udp" }
    34  func (ma *mockAddr) String() string  { return ma.addr }
    35  
    36  type mockAddrListener struct {
    37  	*MockQUICEarlyListener
    38  	addr *mockAddr
    39  }
    40  
    41  func (m *mockAddrListener) Addr() net.Addr {
    42  	_ = m.MockQUICEarlyListener.Addr()
    43  	return m.addr
    44  }
    45  
    46  func newMockAddrListener(addr string) *mockAddrListener {
    47  	return &mockAddrListener{
    48  		MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl),
    49  		addr:                  &mockAddr{addr: addr},
    50  	}
    51  }
    52  
    53  type noPortListener struct {
    54  	*mockAddrListener
    55  }
    56  
    57  func (m *noPortListener) Addr() net.Addr {
    58  	_ = m.mockAddrListener.Addr()
    59  	return &net.UnixAddr{
    60  		Net:  "unix",
    61  		Name: "/tmp/quic.sock",
    62  	}
    63  }
    64  
    65  var _ = Describe("Server", func() {
    66  	var (
    67  		s                  *Server
    68  		origQuicListenAddr = quicListenAddr
    69  	)
    70  
    71  	BeforeEach(func() {
    72  		s = &Server{
    73  			TLSConfig: testdata.GetTLSConfig(),
    74  			logger:    utils.DefaultLogger,
    75  		}
    76  		origQuicListenAddr = quicListenAddr
    77  	})
    78  
    79  	AfterEach(func() {
    80  		quicListenAddr = origQuicListenAddr
    81  	})
    82  
    83  	Context("handling requests", func() {
    84  		var (
    85  			qpackDecoder       *qpack.Decoder
    86  			str                *mockquic.MockStream
    87  			conn               *mockquic.MockEarlyConnection
    88  			exampleGetRequest  *http.Request
    89  			examplePostRequest *http.Request
    90  		)
    91  		reqContext := context.Background()
    92  
    93  		decodeHeader := func(str io.Reader) map[string][]string {
    94  			fields := make(map[string][]string)
    95  			decoder := qpack.NewDecoder(nil)
    96  
    97  			frame, err := parseNextFrame(str, nil)
    98  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
    99  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   100  			headersFrame := frame.(*headersFrame)
   101  			data := make([]byte, headersFrame.Length)
   102  			_, err = io.ReadFull(str, data)
   103  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   104  			hfs, err := decoder.DecodeFull(data)
   105  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   106  			for _, p := range hfs {
   107  				fields[p.Name] = append(fields[p.Name], p.Value)
   108  			}
   109  			return fields
   110  		}
   111  
   112  		encodeRequest := func(req *http.Request) []byte {
   113  			buf := &bytes.Buffer{}
   114  			str := mockquic.NewMockStream(mockCtrl)
   115  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   116  			rw := newRequestWriter(utils.DefaultLogger)
   117  			Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
   118  			return buf.Bytes()
   119  		}
   120  
   121  		setRequest := func(data []byte) {
   122  			buf := bytes.NewBuffer(data)
   123  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   124  				if buf.Len() == 0 {
   125  					return 0, io.EOF
   126  				}
   127  				return buf.Read(p)
   128  			}).AnyTimes()
   129  		}
   130  
   131  		BeforeEach(func() {
   132  			var err error
   133  			exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
   134  			Expect(err).ToNot(HaveOccurred())
   135  			examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
   136  			Expect(err).ToNot(HaveOccurred())
   137  
   138  			qpackDecoder = qpack.NewDecoder(nil)
   139  			str = mockquic.NewMockStream(mockCtrl)
   140  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   141  			addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   142  			conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   143  			conn.EXPECT().LocalAddr().AnyTimes()
   144  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   145  		})
   146  
   147  		It("calls the HTTP handler function", func() {
   148  			requestChan := make(chan *http.Request, 1)
   149  			s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
   150  				requestChan <- r
   151  			})
   152  
   153  			setRequest(encodeRequest(exampleGetRequest))
   154  			str.EXPECT().Context().Return(reqContext)
   155  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   156  				return len(p), nil
   157  			}).AnyTimes()
   158  			str.EXPECT().CancelRead(gomock.Any())
   159  
   160  			Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
   161  			var req *http.Request
   162  			Eventually(requestChan).Should(Receive(&req))
   163  			Expect(req.Host).To(Equal("www.example.com"))
   164  			Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
   165  			Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
   166  		})
   167  
   168  		It("returns 200 with an empty handler", func() {
   169  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   170  
   171  			responseBuf := &bytes.Buffer{}
   172  			setRequest(encodeRequest(exampleGetRequest))
   173  			str.EXPECT().Context().Return(reqContext)
   174  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   175  			str.EXPECT().CancelRead(gomock.Any())
   176  
   177  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   178  			Expect(serr.err).ToNot(HaveOccurred())
   179  			hfs := decodeHeader(responseBuf)
   180  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   181  		})
   182  
   183  		It("sets Content-Length when the handler doesn't flush to the client", func() {
   184  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   185  				w.Write([]byte("foobar"))
   186  			})
   187  
   188  			responseBuf := &bytes.Buffer{}
   189  			setRequest(encodeRequest(exampleGetRequest))
   190  			str.EXPECT().Context().Return(reqContext)
   191  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   192  			str.EXPECT().CancelRead(gomock.Any())
   193  
   194  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   195  			Expect(serr.err).ToNot(HaveOccurred())
   196  			hfs := decodeHeader(responseBuf)
   197  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   198  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
   199  			// status, content-length, date, content-type
   200  			Expect(hfs).To(HaveLen(4))
   201  		})
   202  
   203  		It("not sets Content-Length when the handler flushes to the client", func() {
   204  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   205  				w.Write([]byte("foobar"))
   206  				// force flush
   207  				w.(http.Flusher).Flush()
   208  			})
   209  
   210  			responseBuf := &bytes.Buffer{}
   211  			setRequest(encodeRequest(exampleGetRequest))
   212  			str.EXPECT().Context().Return(reqContext)
   213  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   214  			str.EXPECT().CancelRead(gomock.Any())
   215  
   216  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   217  			Expect(serr.err).ToNot(HaveOccurred())
   218  			hfs := decodeHeader(responseBuf)
   219  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   220  			// status, date, content-type
   221  			Expect(hfs).To(HaveLen(3))
   222  		})
   223  
   224  		It("handles a aborting handler", func() {
   225  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   226  				panic(http.ErrAbortHandler)
   227  			})
   228  
   229  			responseBuf := &bytes.Buffer{}
   230  			setRequest(encodeRequest(exampleGetRequest))
   231  			str.EXPECT().Context().Return(reqContext)
   232  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   233  			str.EXPECT().CancelRead(gomock.Any())
   234  
   235  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   236  			Expect(serr.err).ToNot(HaveOccurred())
   237  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   238  		})
   239  
   240  		It("handles a panicking handler", func() {
   241  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   242  				panic("foobar")
   243  			})
   244  
   245  			responseBuf := &bytes.Buffer{}
   246  			setRequest(encodeRequest(exampleGetRequest))
   247  			str.EXPECT().Context().Return(reqContext)
   248  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   249  			str.EXPECT().CancelRead(gomock.Any())
   250  
   251  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   252  			Expect(serr.err).ToNot(HaveOccurred())
   253  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   254  		})
   255  
   256  		Context("hijacking bidirectional streams", func() {
   257  			var conn *mockquic.MockEarlyConnection
   258  			testDone := make(chan struct{})
   259  
   260  			BeforeEach(func() {
   261  				testDone = make(chan struct{})
   262  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   263  				controlStr := mockquic.NewMockStream(mockCtrl)
   264  				controlStr.EXPECT().Write(gomock.Any())
   265  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   266  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   267  				conn.EXPECT().LocalAddr().AnyTimes()
   268  			})
   269  
   270  			AfterEach(func() { testDone <- struct{}{} })
   271  
   272  			It("hijacks a bidirectional stream of unknown frame type", func() {
   273  				frameTypeChan := make(chan FrameType, 1)
   274  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   275  					Expect(e).ToNot(HaveOccurred())
   276  					frameTypeChan <- ft
   277  					return true, nil
   278  				}
   279  
   280  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   281  				unknownStr := mockquic.NewMockStream(mockCtrl)
   282  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   283  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   284  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   285  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   286  					<-testDone
   287  					return nil, errors.New("test done")
   288  				})
   289  				s.handleConn(conn)
   290  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   291  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   292  			})
   293  
   294  			It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
   295  				frameTypeChan := make(chan FrameType, 1)
   296  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   297  					Expect(e).ToNot(HaveOccurred())
   298  					frameTypeChan <- ft
   299  					return false, nil
   300  				}
   301  
   302  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   303  				unknownStr := mockquic.NewMockStream(mockCtrl)
   304  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   305  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   306  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   307  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   308  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   309  					<-testDone
   310  					return nil, errors.New("test done")
   311  				})
   312  				s.handleConn(conn)
   313  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   314  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   315  			})
   316  
   317  			It("cancels writing when hijacker returned error", func() {
   318  				frameTypeChan := make(chan FrameType, 1)
   319  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   320  					Expect(e).ToNot(HaveOccurred())
   321  					frameTypeChan <- ft
   322  					return false, errors.New("error in hijacker")
   323  				}
   324  
   325  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   326  				unknownStr := mockquic.NewMockStream(mockCtrl)
   327  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   328  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   329  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   330  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   331  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   332  					<-testDone
   333  					return nil, errors.New("test done")
   334  				})
   335  				s.handleConn(conn)
   336  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   337  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   338  			})
   339  
   340  			It("handles errors that occur when reading the stream type", func() {
   341  				testErr := errors.New("test error")
   342  				done := make(chan struct{})
   343  				unknownStr := mockquic.NewMockStream(mockCtrl)
   344  				s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
   345  					defer close(done)
   346  					Expect(ft).To(BeZero())
   347  					Expect(str).To(Equal(unknownStr))
   348  					Expect(err).To(MatchError(testErr))
   349  					return true, nil
   350  				}
   351  
   352  				unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   353  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   354  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   355  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   356  					<-testDone
   357  					return nil, errors.New("test done")
   358  				})
   359  				s.handleConn(conn)
   360  				Eventually(done).Should(BeClosed())
   361  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   362  			})
   363  		})
   364  
   365  		Context("hijacking unidirectional streams", func() {
   366  			var conn *mockquic.MockEarlyConnection
   367  			testDone := make(chan struct{})
   368  
   369  			BeforeEach(func() {
   370  				testDone = make(chan struct{})
   371  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   372  				controlStr := mockquic.NewMockStream(mockCtrl)
   373  				controlStr.EXPECT().Write(gomock.Any())
   374  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   375  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   376  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   377  				conn.EXPECT().LocalAddr().AnyTimes()
   378  			})
   379  
   380  			AfterEach(func() { testDone <- struct{}{} })
   381  
   382  			It("hijacks an unidirectional stream of unknown stream type", func() {
   383  				streamTypeChan := make(chan StreamType, 1)
   384  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   385  					Expect(err).ToNot(HaveOccurred())
   386  					streamTypeChan <- st
   387  					return true
   388  				}
   389  
   390  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   391  				unknownStr := mockquic.NewMockStream(mockCtrl)
   392  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   393  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   394  					return unknownStr, nil
   395  				})
   396  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   397  					<-testDone
   398  					return nil, errors.New("test done")
   399  				})
   400  				s.handleConn(conn)
   401  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   402  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   403  			})
   404  
   405  			It("handles errors that occur when reading the stream type", func() {
   406  				testErr := errors.New("test error")
   407  				done := make(chan struct{})
   408  				unknownStr := mockquic.NewMockStream(mockCtrl)
   409  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   410  					defer close(done)
   411  					Expect(st).To(BeZero())
   412  					Expect(str).To(Equal(unknownStr))
   413  					Expect(err).To(MatchError(testErr))
   414  					return true
   415  				}
   416  
   417  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
   418  				conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   419  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   420  					<-testDone
   421  					return nil, errors.New("test done")
   422  				})
   423  				s.handleConn(conn)
   424  				Eventually(done).Should(BeClosed())
   425  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   426  			})
   427  
   428  			It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   429  				streamTypeChan := make(chan StreamType, 1)
   430  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   431  					Expect(err).ToNot(HaveOccurred())
   432  					streamTypeChan <- st
   433  					return false
   434  				}
   435  
   436  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   437  				unknownStr := mockquic.NewMockStream(mockCtrl)
   438  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   439  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   440  
   441  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   442  					return unknownStr, nil
   443  				})
   444  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   445  					<-testDone
   446  					return nil, errors.New("test done")
   447  				})
   448  				s.handleConn(conn)
   449  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   450  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   451  			})
   452  		})
   453  
   454  		Context("control stream handling", func() {
   455  			var conn *mockquic.MockEarlyConnection
   456  			testDone := make(chan struct{})
   457  
   458  			BeforeEach(func() {
   459  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   460  				controlStr := mockquic.NewMockStream(mockCtrl)
   461  				controlStr.EXPECT().Write(gomock.Any())
   462  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   463  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   464  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   465  				conn.EXPECT().LocalAddr().AnyTimes()
   466  			})
   467  
   468  			AfterEach(func() { testDone <- struct{}{} })
   469  
   470  			It("parses the SETTINGS frame", func() {
   471  				b := quicvarint.Append(nil, streamTypeControlStream)
   472  				b = (&settingsFrame{}).Append(b)
   473  				controlStr := mockquic.NewMockStream(mockCtrl)
   474  				r := bytes.NewReader(b)
   475  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   476  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   477  					return controlStr, nil
   478  				})
   479  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   480  					<-testDone
   481  					return nil, errors.New("test done")
   482  				})
   483  				s.handleConn(conn)
   484  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   485  			})
   486  
   487  			for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   488  				streamType := t
   489  				name := "encoder"
   490  				if streamType == streamTypeQPACKDecoderStream {
   491  					name = "decoder"
   492  				}
   493  
   494  				It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   495  					buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   496  					str := mockquic.NewMockStream(mockCtrl)
   497  					str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   498  
   499  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   500  						return str, nil
   501  					})
   502  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   503  						<-testDone
   504  						return nil, errors.New("test done")
   505  					})
   506  					s.handleConn(conn)
   507  					time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   508  				})
   509  			}
   510  
   511  			It("reset streams other than the control stream and the QPACK streams", func() {
   512  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337))
   513  				str := mockquic.NewMockStream(mockCtrl)
   514  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   515  				done := make(chan struct{})
   516  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) {
   517  					close(done)
   518  				})
   519  
   520  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   521  					return str, nil
   522  				})
   523  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   524  					<-testDone
   525  					return nil, errors.New("test done")
   526  				})
   527  				s.handleConn(conn)
   528  				Eventually(done).Should(BeClosed())
   529  			})
   530  
   531  			It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   532  				b := quicvarint.Append(nil, streamTypeControlStream)
   533  				b = (&dataFrame{}).Append(b)
   534  				controlStr := mockquic.NewMockStream(mockCtrl)
   535  				r := bytes.NewReader(b)
   536  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   537  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   538  					return controlStr, nil
   539  				})
   540  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   541  					<-testDone
   542  					return nil, errors.New("test done")
   543  				})
   544  				done := make(chan struct{})
   545  				conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   546  					defer GinkgoRecover()
   547  					Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings))
   548  					close(done)
   549  				})
   550  				s.handleConn(conn)
   551  				Eventually(done).Should(BeClosed())
   552  			})
   553  
   554  			It("errors when parsing the frame on the control stream fails", func() {
   555  				b := quicvarint.Append(nil, streamTypeControlStream)
   556  				b = (&settingsFrame{}).Append(b)
   557  				r := bytes.NewReader(b[:len(b)-1])
   558  				controlStr := mockquic.NewMockStream(mockCtrl)
   559  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   560  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   561  					return controlStr, nil
   562  				})
   563  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   564  					<-testDone
   565  					return nil, errors.New("test done")
   566  				})
   567  				done := make(chan struct{})
   568  				conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   569  					defer GinkgoRecover()
   570  					Expect(code).To(BeEquivalentTo(ErrCodeFrameError))
   571  					close(done)
   572  				})
   573  				s.handleConn(conn)
   574  				Eventually(done).Should(BeClosed())
   575  			})
   576  
   577  			It("errors when the client opens a push stream", func() {
   578  				b := quicvarint.Append(nil, streamTypePushStream)
   579  				b = (&dataFrame{}).Append(b)
   580  				r := bytes.NewReader(b)
   581  				controlStr := mockquic.NewMockStream(mockCtrl)
   582  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   583  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   584  					return controlStr, nil
   585  				})
   586  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   587  					<-testDone
   588  					return nil, errors.New("test done")
   589  				})
   590  				done := make(chan struct{})
   591  				conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   592  					defer GinkgoRecover()
   593  					Expect(code).To(BeEquivalentTo(ErrCodeStreamCreationError))
   594  					close(done)
   595  				})
   596  				s.handleConn(conn)
   597  				Eventually(done).Should(BeClosed())
   598  			})
   599  
   600  			It("errors when the client advertises datagram support (and we enabled support for it)", func() {
   601  				s.EnableDatagrams = true
   602  				b := quicvarint.Append(nil, streamTypeControlStream)
   603  				b = (&settingsFrame{Datagram: true}).Append(b)
   604  				r := bytes.NewReader(b)
   605  				controlStr := mockquic.NewMockStream(mockCtrl)
   606  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   607  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   608  					return controlStr, nil
   609  				})
   610  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   611  					<-testDone
   612  					return nil, errors.New("test done")
   613  				})
   614  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   615  				done := make(chan struct{})
   616  				conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
   617  					defer GinkgoRecover()
   618  					Expect(code).To(BeEquivalentTo(ErrCodeSettingsError))
   619  					Expect(reason).To(Equal("missing QUIC Datagram support"))
   620  					close(done)
   621  				})
   622  				s.handleConn(conn)
   623  				Eventually(done).Should(BeClosed())
   624  			})
   625  		})
   626  
   627  		Context("stream- and connection-level errors", func() {
   628  			var conn *mockquic.MockEarlyConnection
   629  			testDone := make(chan struct{})
   630  
   631  			BeforeEach(func() {
   632  				testDone = make(chan struct{})
   633  				addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   634  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   635  				controlStr := mockquic.NewMockStream(mockCtrl)
   636  				controlStr.EXPECT().Write(gomock.Any())
   637  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   638  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   639  					<-testDone
   640  					return nil, errors.New("test done")
   641  				})
   642  				conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
   643  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   644  				conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   645  				conn.EXPECT().LocalAddr().AnyTimes()
   646  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   647  			})
   648  
   649  			AfterEach(func() { testDone <- struct{}{} })
   650  
   651  			It("cancels reading when client sends a body in GET request", func() {
   652  				var handlerCalled bool
   653  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   654  					handlerCalled = true
   655  				})
   656  
   657  				requestData := encodeRequest(exampleGetRequest)
   658  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   659  				b = append(b, []byte("foobar")...)
   660  				responseBuf := &bytes.Buffer{}
   661  				setRequest(append(requestData, b...))
   662  				done := make(chan struct{})
   663  				str.EXPECT().Context().Return(reqContext)
   664  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   665  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   666  				str.EXPECT().Close().Do(func() { close(done) })
   667  
   668  				s.handleConn(conn)
   669  				Eventually(done).Should(BeClosed())
   670  				hfs := decodeHeader(responseBuf)
   671  				Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   672  				Expect(handlerCalled).To(BeTrue())
   673  			})
   674  
   675  			It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
   676  				handlerCalled := make(chan struct{})
   677  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   678  					defer close(handlerCalled)
   679  					r.Body.(HTTPStreamer).HTTPStream()
   680  					str.Write([]byte("foobar"))
   681  				})
   682  
   683  				requestData := encodeRequest(exampleGetRequest)
   684  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   685  				b = append(b, []byte("foobar")...)
   686  				setRequest(append(requestData, b...))
   687  				str.EXPECT().Context().Return(reqContext)
   688  				str.EXPECT().Write([]byte("foobar")).Return(6, nil)
   689  
   690  				s.handleConn(conn)
   691  				Eventually(handlerCalled).Should(BeClosed())
   692  			})
   693  
   694  			It("errors when the client sends a too large header frame", func() {
   695  				s.MaxHeaderBytes = 20
   696  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   697  					Fail("Handler should not be called.")
   698  				})
   699  
   700  				requestData := encodeRequest(exampleGetRequest)
   701  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   702  				b = append(b, []byte("foobar")...)
   703  				responseBuf := &bytes.Buffer{}
   704  				setRequest(append(requestData, b...))
   705  				done := make(chan struct{})
   706  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   707  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   708  
   709  				s.handleConn(conn)
   710  				Eventually(done).Should(BeClosed())
   711  			})
   712  
   713  			It("handles a request for which the client immediately resets the stream", func() {
   714  				handlerCalled := make(chan struct{})
   715  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   716  					close(handlerCalled)
   717  				})
   718  
   719  				testErr := errors.New("stream reset")
   720  				done := make(chan struct{})
   721  				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
   722  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
   723  
   724  				s.handleConn(conn)
   725  				Consistently(handlerCalled).ShouldNot(BeClosed())
   726  			})
   727  
   728  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   729  				handlerCalled := make(chan struct{})
   730  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   731  					close(handlerCalled)
   732  				})
   733  
   734  				b := (&dataFrame{}).Append(nil)
   735  				setRequest(b)
   736  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   737  					return len(p), nil
   738  				}).AnyTimes()
   739  
   740  				done := make(chan struct{})
   741  				conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
   742  					Expect(code).To(Equal(quic.ApplicationErrorCode(ErrCodeFrameUnexpected)))
   743  					close(done)
   744  				})
   745  				s.handleConn(conn)
   746  				Eventually(done).Should(BeClosed())
   747  			})
   748  
   749  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   750  				handlerCalled := make(chan struct{})
   751  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   752  					close(handlerCalled)
   753  				})
   754  
   755  				// use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest,
   756  				// but the request will still end up larger than DefaultMaxHeaderBytes.
   757  				url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
   758  				req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
   759  				Expect(err).ToNot(HaveOccurred())
   760  				setRequest(encodeRequest(req))
   761  				// str.EXPECT().Context().Return(reqContext)
   762  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   763  					return len(p), nil
   764  				}).AnyTimes()
   765  				done := make(chan struct{})
   766  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   767  
   768  				s.handleConn(conn)
   769  				Eventually(done).Should(BeClosed())
   770  			})
   771  		})
   772  
   773  		It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
   774  			handlerCalled := make(chan struct{})
   775  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   776  				r.Body = struct {
   777  					io.Reader
   778  					io.Closer
   779  				}{}
   780  				close(handlerCalled)
   781  			})
   782  
   783  			setRequest(encodeRequest(examplePostRequest))
   784  			str.EXPECT().Context().Return(reqContext)
   785  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   786  				return len(p), nil
   787  			}).AnyTimes()
   788  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   789  
   790  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   791  			Expect(serr.err).ToNot(HaveOccurred())
   792  			Eventually(handlerCalled).Should(BeClosed())
   793  		})
   794  
   795  		It("cancels the request context when the stream is closed", func() {
   796  			handlerCalled := make(chan struct{})
   797  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   798  				defer GinkgoRecover()
   799  				Expect(r.Context().Done()).To(BeClosed())
   800  				Expect(r.Context().Err()).To(MatchError(context.Canceled))
   801  				close(handlerCalled)
   802  			})
   803  			setRequest(encodeRequest(examplePostRequest))
   804  
   805  			reqContext, cancel := context.WithCancel(context.Background())
   806  			cancel()
   807  			str.EXPECT().Context().Return(reqContext)
   808  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   809  				return len(p), nil
   810  			}).AnyTimes()
   811  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   812  
   813  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   814  			Expect(serr.err).ToNot(HaveOccurred())
   815  			Eventually(handlerCalled).Should(BeClosed())
   816  		})
   817  	})
   818  
   819  	Context("setting http headers", func() {
   820  		BeforeEach(func() {
   821  			s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}}
   822  		})
   823  
   824  		var ln1 QUICEarlyListener
   825  		var ln2 QUICEarlyListener
   826  		expected := http.Header{
   827  			"Alt-Svc": {`h3=":443"; ma=2592000`},
   828  		}
   829  
   830  		addListener := func(addr string, ln *QUICEarlyListener) {
   831  			mln := newMockAddrListener(addr)
   832  			mln.EXPECT().Addr()
   833  			*ln = mln
   834  			s.addListener(ln)
   835  		}
   836  
   837  		removeListener := func(ln *QUICEarlyListener) {
   838  			s.removeListener(ln)
   839  		}
   840  
   841  		checkSetHeaders := func(expected gmtypes.GomegaMatcher) {
   842  			hdr := http.Header{}
   843  			Expect(s.SetQuicHeaders(hdr)).To(Succeed())
   844  			Expect(hdr).To(expected)
   845  		}
   846  
   847  		checkSetHeaderError := func() {
   848  			hdr := http.Header{}
   849  			Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
   850  		}
   851  
   852  		It("sets proper headers with numeric port", func() {
   853  			addListener(":443", &ln1)
   854  			checkSetHeaders(Equal(expected))
   855  			removeListener(&ln1)
   856  			checkSetHeaderError()
   857  		})
   858  
   859  		It("sets proper headers with full addr", func() {
   860  			addListener("127.0.0.1:443", &ln1)
   861  			checkSetHeaders(Equal(expected))
   862  			removeListener(&ln1)
   863  			checkSetHeaderError()
   864  		})
   865  
   866  		It("sets proper headers with string port", func() {
   867  			addListener(":https", &ln1)
   868  			checkSetHeaders(Equal(expected))
   869  			removeListener(&ln1)
   870  			checkSetHeaderError()
   871  		})
   872  
   873  		It("works multiple times", func() {
   874  			addListener(":https", &ln1)
   875  			checkSetHeaders(Equal(expected))
   876  			checkSetHeaders(Equal(expected))
   877  			removeListener(&ln1)
   878  			checkSetHeaderError()
   879  		})
   880  
   881  		It("works if the quic.Config sets QUIC versions", func() {
   882  			s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version2}
   883  			addListener(":443", &ln1)
   884  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   885  			removeListener(&ln1)
   886  			checkSetHeaderError()
   887  		})
   888  
   889  		It("uses s.Port if set to a non-zero value", func() {
   890  			s.Port = 8443
   891  			addListener(":443", &ln1)
   892  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
   893  			removeListener(&ln1)
   894  			checkSetHeaderError()
   895  		})
   896  
   897  		It("uses s.Addr if listeners don't have ports available", func() {
   898  			s.Addr = ":443"
   899  			mln := &noPortListener{newMockAddrListener("")}
   900  			mln.EXPECT().Addr()
   901  			ln1 = mln
   902  			s.addListener(&ln1)
   903  			checkSetHeaders(Equal(expected))
   904  			s.removeListener(&ln1)
   905  			checkSetHeaderError()
   906  		})
   907  
   908  		It("properly announces multiple listeners", func() {
   909  			addListener(":443", &ln1)
   910  			addListener(":8443", &ln2)
   911  			checkSetHeaders(Or(
   912  				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
   913  				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
   914  			))
   915  			removeListener(&ln1)
   916  			removeListener(&ln2)
   917  			checkSetHeaderError()
   918  		})
   919  
   920  		It("doesn't duplicate Alt-Svc values", func() {
   921  			s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version1}
   922  			addListener(":443", &ln1)
   923  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   924  			removeListener(&ln1)
   925  			checkSetHeaderError()
   926  		})
   927  	})
   928  
   929  	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
   930  		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
   931  	})
   932  
   933  	It("should nop-Close() when s.server is nil", func() {
   934  		Expect((&Server{}).Close()).To(Succeed())
   935  	})
   936  
   937  	It("errors when ListenAndServeTLS is called after Close", func() {
   938  		serv := &Server{}
   939  		Expect(serv.Close()).To(Succeed())
   940  		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
   941  	})
   942  
   943  	It("handles concurrent Serve and Close", func() {
   944  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   945  		Expect(err).ToNot(HaveOccurred())
   946  		c, err := net.ListenUDP("udp", addr)
   947  		Expect(err).ToNot(HaveOccurred())
   948  		done := make(chan struct{})
   949  		go func() {
   950  			defer GinkgoRecover()
   951  			defer close(done)
   952  			s.Serve(c)
   953  		}()
   954  		runtime.Gosched()
   955  		s.Close()
   956  		Eventually(done).Should(BeClosed())
   957  	})
   958  
   959  	Context("ConfigureTLSConfig", func() {
   960  		It("advertises v1 by default", func() {
   961  			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
   962  			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
   963  			Expect(err).ToNot(HaveOccurred())
   964  			defer ln.Close()
   965  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   966  			Expect(err).ToNot(HaveOccurred())
   967  			defer c.CloseWithError(0, "")
   968  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   969  		})
   970  
   971  		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
   972  			var receivedConf *tls.Config
   973  			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
   974  				receivedConf = tlsConf
   975  				return nil, errors.New("listen err")
   976  			}
   977  			Expect(s.ListenAndServe()).To(HaveOccurred())
   978  			Expect(receivedConf).ToNot(BeNil())
   979  		})
   980  
   981  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
   982  			tlsConf := &tls.Config{
   983  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
   984  					c := testdata.GetTLSConfig()
   985  					c.NextProtos = []string{"foo", "bar"}
   986  					return c, nil
   987  				},
   988  			}
   989  
   990  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
   991  			Expect(err).ToNot(HaveOccurred())
   992  			defer ln.Close()
   993  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
   994  			Expect(err).ToNot(HaveOccurred())
   995  			defer c.CloseWithError(0, "")
   996  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
   997  		})
   998  
   999  		It("works if GetConfigForClient returns a nil tls.Config", func() {
  1000  			tlsConf := testdata.GetTLSConfig()
  1001  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
  1002  
  1003  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
  1004  			Expect(err).ToNot(HaveOccurred())
  1005  			defer ln.Close()
  1006  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1007  			Expect(err).ToNot(HaveOccurred())
  1008  			defer c.CloseWithError(0, "")
  1009  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1010  		})
  1011  
  1012  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
  1013  			tlsClientConf := testdata.GetTLSConfig()
  1014  			tlsClientConf.NextProtos = []string{"foo", "bar"}
  1015  			tlsConf := &tls.Config{
  1016  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  1017  					return tlsClientConf, nil
  1018  				},
  1019  			}
  1020  
  1021  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
  1022  			Expect(err).ToNot(HaveOccurred())
  1023  			defer ln.Close()
  1024  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1025  			Expect(err).ToNot(HaveOccurred())
  1026  			defer c.CloseWithError(0, "")
  1027  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1028  			// check that the original config was not modified
  1029  			Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
  1030  		})
  1031  	})
  1032  
  1033  	Context("Serve", func() {
  1034  		origQuicListen := quicListen
  1035  
  1036  		AfterEach(func() {
  1037  			quicListen = origQuicListen
  1038  		})
  1039  
  1040  		It("serves a packet conn", func() {
  1041  			ln := newMockAddrListener(":443")
  1042  			conn := &net.UDPConn{}
  1043  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1044  				Expect(c).To(Equal(conn))
  1045  				return ln, nil
  1046  			}
  1047  
  1048  			s := &Server{
  1049  				TLSConfig: &tls.Config{},
  1050  			}
  1051  
  1052  			stopAccept := make(chan struct{})
  1053  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1054  				<-stopAccept
  1055  				return nil, errors.New("closed")
  1056  			})
  1057  			ln.EXPECT().Addr() // generate alt-svc headers
  1058  			done := make(chan struct{})
  1059  			go func() {
  1060  				defer GinkgoRecover()
  1061  				defer close(done)
  1062  				s.Serve(conn)
  1063  			}()
  1064  
  1065  			Consistently(done).ShouldNot(BeClosed())
  1066  			ln.EXPECT().Close().Do(func() { close(stopAccept) })
  1067  			Expect(s.Close()).To(Succeed())
  1068  			Eventually(done).Should(BeClosed())
  1069  		})
  1070  
  1071  		It("serves two packet conns", func() {
  1072  			ln1 := newMockAddrListener(":443")
  1073  			ln2 := newMockAddrListener(":8443")
  1074  			lns := make(chan QUICEarlyListener, 2)
  1075  			lns <- ln1
  1076  			lns <- ln2
  1077  			conn1 := &net.UDPConn{}
  1078  			conn2 := &net.UDPConn{}
  1079  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1080  				return <-lns, nil
  1081  			}
  1082  
  1083  			s := &Server{
  1084  				TLSConfig: &tls.Config{},
  1085  			}
  1086  
  1087  			stopAccept1 := make(chan struct{})
  1088  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1089  				<-stopAccept1
  1090  				return nil, errors.New("closed")
  1091  			})
  1092  			ln1.EXPECT().Addr() // generate alt-svc headers
  1093  			stopAccept2 := make(chan struct{})
  1094  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1095  				<-stopAccept2
  1096  				return nil, errors.New("closed")
  1097  			})
  1098  			ln2.EXPECT().Addr()
  1099  
  1100  			done1 := make(chan struct{})
  1101  			go func() {
  1102  				defer GinkgoRecover()
  1103  				defer close(done1)
  1104  				s.Serve(conn1)
  1105  			}()
  1106  			done2 := make(chan struct{})
  1107  			go func() {
  1108  				defer GinkgoRecover()
  1109  				defer close(done2)
  1110  				s.Serve(conn2)
  1111  			}()
  1112  
  1113  			Consistently(done1).ShouldNot(BeClosed())
  1114  			Expect(done2).ToNot(BeClosed())
  1115  			ln1.EXPECT().Close().Do(func() { close(stopAccept1) })
  1116  			ln2.EXPECT().Close().Do(func() { close(stopAccept2) })
  1117  			Expect(s.Close()).To(Succeed())
  1118  			Eventually(done1).Should(BeClosed())
  1119  			Eventually(done2).Should(BeClosed())
  1120  		})
  1121  	})
  1122  
  1123  	Context("ServeListener", func() {
  1124  		origQuicListen := quicListen
  1125  
  1126  		AfterEach(func() {
  1127  			quicListen = origQuicListen
  1128  		})
  1129  
  1130  		It("serves a listener", func() {
  1131  			var called int32
  1132  			ln := newMockAddrListener(":443")
  1133  			quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1134  				atomic.StoreInt32(&called, 1)
  1135  				return ln, nil
  1136  			}
  1137  
  1138  			s := &Server{}
  1139  
  1140  			stopAccept := make(chan struct{})
  1141  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1142  				<-stopAccept
  1143  				return nil, errors.New("closed")
  1144  			})
  1145  			ln.EXPECT().Addr() // generate alt-svc headers
  1146  			done := make(chan struct{})
  1147  			go func() {
  1148  				defer GinkgoRecover()
  1149  				defer close(done)
  1150  				s.ServeListener(ln)
  1151  			}()
  1152  
  1153  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1154  			Consistently(done).ShouldNot(BeClosed())
  1155  			ln.EXPECT().Close().Do(func() { close(stopAccept) })
  1156  			Expect(s.Close()).To(Succeed())
  1157  			Eventually(done).Should(BeClosed())
  1158  		})
  1159  
  1160  		It("serves two listeners", func() {
  1161  			var called int32
  1162  			ln1 := newMockAddrListener(":443")
  1163  			ln2 := newMockAddrListener(":8443")
  1164  			lns := make(chan QUICEarlyListener, 2)
  1165  			lns <- ln1
  1166  			lns <- ln2
  1167  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1168  				atomic.StoreInt32(&called, 1)
  1169  				return <-lns, nil
  1170  			}
  1171  
  1172  			s := &Server{}
  1173  
  1174  			stopAccept1 := make(chan struct{})
  1175  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1176  				<-stopAccept1
  1177  				return nil, errors.New("closed")
  1178  			})
  1179  			ln1.EXPECT().Addr() // generate alt-svc headers
  1180  			stopAccept2 := make(chan struct{})
  1181  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) {
  1182  				<-stopAccept2
  1183  				return nil, errors.New("closed")
  1184  			})
  1185  			ln2.EXPECT().Addr()
  1186  
  1187  			done1 := make(chan struct{})
  1188  			go func() {
  1189  				defer GinkgoRecover()
  1190  				defer close(done1)
  1191  				s.ServeListener(ln1)
  1192  			}()
  1193  			done2 := make(chan struct{})
  1194  			go func() {
  1195  				defer GinkgoRecover()
  1196  				defer close(done2)
  1197  				s.ServeListener(ln2)
  1198  			}()
  1199  
  1200  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1201  			Consistently(done1).ShouldNot(BeClosed())
  1202  			Expect(done2).ToNot(BeClosed())
  1203  			ln1.EXPECT().Close().Do(func() { close(stopAccept1) })
  1204  			ln2.EXPECT().Close().Do(func() { close(stopAccept2) })
  1205  			Expect(s.Close()).To(Succeed())
  1206  			Eventually(done1).Should(BeClosed())
  1207  			Eventually(done2).Should(BeClosed())
  1208  		})
  1209  	})
  1210  
  1211  	Context("ServeQUICConn", func() {
  1212  		It("serves a QUIC connection", func() {
  1213  			mux := http.NewServeMux()
  1214  			mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
  1215  				w.Write([]byte("foobar"))
  1216  			})
  1217  			s.Handler = mux
  1218  			tlsConf := testdata.GetTLSConfig()
  1219  			tlsConf.NextProtos = []string{NextProtoH3}
  1220  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
  1221  			controlStr := mockquic.NewMockStream(mockCtrl)
  1222  			controlStr.EXPECT().Write(gomock.Any())
  1223  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
  1224  			testDone := make(chan struct{})
  1225  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
  1226  				<-testDone
  1227  				return nil, errors.New("test done")
  1228  			}).MaxTimes(1)
  1229  			conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})
  1230  			s.ServeQUICConn(conn)
  1231  			close(testDone)
  1232  		})
  1233  	})
  1234  
  1235  	Context("ListenAndServe", func() {
  1236  		BeforeEach(func() {
  1237  			s.Addr = "localhost:0"
  1238  		})
  1239  
  1240  		AfterEach(func() {
  1241  			Expect(s.Close()).To(Succeed())
  1242  		})
  1243  
  1244  		It("uses the quic.Config to start the QUIC server", func() {
  1245  			conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
  1246  			var receivedConf *quic.Config
  1247  			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1248  				receivedConf = config
  1249  				return nil, errors.New("listen err")
  1250  			}
  1251  			s.QuicConfig = conf
  1252  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1253  			Expect(receivedConf).To(Equal(conf))
  1254  		})
  1255  	})
  1256  
  1257  	It("closes gracefully", func() {
  1258  		Expect(s.CloseGracefully(0)).To(Succeed())
  1259  	})
  1260  
  1261  	It("errors when listening fails", func() {
  1262  		testErr := errors.New("listen error")
  1263  		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1264  			return nil, testErr
  1265  		}
  1266  		fullpem, privkey := testdata.GetCertificatePaths()
  1267  		Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
  1268  	})
  1269  
  1270  	It("supports H3_DATAGRAM", func() {
  1271  		s.EnableDatagrams = true
  1272  		var receivedConf *quic.Config
  1273  		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1274  			receivedConf = config
  1275  			return nil, errors.New("listen err")
  1276  		}
  1277  		Expect(s.ListenAndServe()).To(HaveOccurred())
  1278  		Expect(receivedConf.EnableDatagrams).To(BeTrue())
  1279  	})
  1280  })