github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/server_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"runtime"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/daeuniverse/quic-go"
    17  	mockquic "github.com/daeuniverse/quic-go/internal/mocks/quic"
    18  	"github.com/daeuniverse/quic-go/internal/protocol"
    19  	"github.com/daeuniverse/quic-go/internal/qerr"
    20  	"github.com/daeuniverse/quic-go/internal/testdata"
    21  	"github.com/daeuniverse/quic-go/internal/utils"
    22  	"github.com/daeuniverse/quic-go/quicvarint"
    23  
    24  	"github.com/quic-go/qpack"
    25  	"go.uber.org/mock/gomock"
    26  
    27  	. "github.com/onsi/ginkgo/v2"
    28  	. "github.com/onsi/gomega"
    29  	gmtypes "github.com/onsi/gomega/types"
    30  )
    31  
    32  type mockAddr struct{ addr string }
    33  
    34  func (ma *mockAddr) Network() string { return "udp" }
    35  func (ma *mockAddr) String() string  { return ma.addr }
    36  
    37  type mockAddrListener struct {
    38  	*MockQUICEarlyListener
    39  	addr *mockAddr
    40  }
    41  
    42  func (m *mockAddrListener) Addr() net.Addr {
    43  	_ = m.MockQUICEarlyListener.Addr()
    44  	return m.addr
    45  }
    46  
    47  func newMockAddrListener(addr string) *mockAddrListener {
    48  	return &mockAddrListener{
    49  		MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl),
    50  		addr:                  &mockAddr{addr: addr},
    51  	}
    52  }
    53  
    54  type noPortListener struct {
    55  	*mockAddrListener
    56  }
    57  
    58  func (m *noPortListener) Addr() net.Addr {
    59  	_ = m.mockAddrListener.Addr()
    60  	return &net.UnixAddr{
    61  		Net:  "unix",
    62  		Name: "/tmp/quic.sock",
    63  	}
    64  }
    65  
    66  var _ = Describe("Server", func() {
    67  	var (
    68  		s                  *Server
    69  		origQuicListenAddr = quicListenAddr
    70  	)
    71  	type testConnContextKey string
    72  
    73  	BeforeEach(func() {
    74  		s = &Server{
    75  			TLSConfig: testdata.GetTLSConfig(),
    76  			logger:    utils.DefaultLogger,
    77  			ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
    78  				return context.WithValue(ctx, testConnContextKey("test"), c)
    79  			},
    80  		}
    81  		origQuicListenAddr = quicListenAddr
    82  	})
    83  
    84  	AfterEach(func() {
    85  		quicListenAddr = origQuicListenAddr
    86  	})
    87  
    88  	Context("handling requests", func() {
    89  		var (
    90  			qpackDecoder       *qpack.Decoder
    91  			str                *mockquic.MockStream
    92  			conn               *mockquic.MockEarlyConnection
    93  			exampleGetRequest  *http.Request
    94  			examplePostRequest *http.Request
    95  		)
    96  		reqContext := context.Background()
    97  
    98  		decodeHeader := func(str io.Reader) map[string][]string {
    99  			fields := make(map[string][]string)
   100  			decoder := qpack.NewDecoder(nil)
   101  
   102  			frame, err := parseNextFrame(str, nil)
   103  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   104  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   105  			headersFrame := frame.(*headersFrame)
   106  			data := make([]byte, headersFrame.Length)
   107  			_, err = io.ReadFull(str, data)
   108  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   109  			hfs, err := decoder.DecodeFull(data)
   110  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   111  			for _, p := range hfs {
   112  				fields[p.Name] = append(fields[p.Name], p.Value)
   113  			}
   114  			return fields
   115  		}
   116  
   117  		encodeRequest := func(req *http.Request) []byte {
   118  			buf := &bytes.Buffer{}
   119  			str := mockquic.NewMockStream(mockCtrl)
   120  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   121  			rw := newRequestWriter(utils.DefaultLogger)
   122  			Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
   123  			return buf.Bytes()
   124  		}
   125  
   126  		setRequest := func(data []byte) {
   127  			buf := bytes.NewBuffer(data)
   128  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   129  				if buf.Len() == 0 {
   130  					return 0, io.EOF
   131  				}
   132  				return buf.Read(p)
   133  			}).AnyTimes()
   134  		}
   135  
   136  		BeforeEach(func() {
   137  			var err error
   138  			exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
   139  			Expect(err).ToNot(HaveOccurred())
   140  			examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
   141  			Expect(err).ToNot(HaveOccurred())
   142  
   143  			qpackDecoder = qpack.NewDecoder(nil)
   144  			str = mockquic.NewMockStream(mockCtrl)
   145  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   146  			addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   147  			conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   148  			conn.EXPECT().LocalAddr().AnyTimes()
   149  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   150  		})
   151  
   152  		It("calls the HTTP handler function", func() {
   153  			requestChan := make(chan *http.Request, 1)
   154  			s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
   155  				requestChan <- r
   156  			})
   157  
   158  			setRequest(encodeRequest(exampleGetRequest))
   159  			str.EXPECT().Context().Return(reqContext)
   160  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   161  				return len(p), nil
   162  			}).AnyTimes()
   163  			str.EXPECT().CancelRead(gomock.Any())
   164  
   165  			Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
   166  			var req *http.Request
   167  			Eventually(requestChan).Should(Receive(&req))
   168  			Expect(req.Host).To(Equal("www.example.com"))
   169  			Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
   170  			Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
   171  			Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil))
   172  		})
   173  
   174  		It("returns 200 with an empty handler", func() {
   175  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   176  
   177  			responseBuf := &bytes.Buffer{}
   178  			setRequest(encodeRequest(exampleGetRequest))
   179  			str.EXPECT().Context().Return(reqContext)
   180  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   181  			str.EXPECT().CancelRead(gomock.Any())
   182  
   183  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   184  			Expect(serr.err).ToNot(HaveOccurred())
   185  			hfs := decodeHeader(responseBuf)
   186  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   187  		})
   188  
   189  		It("sets Content-Length when the handler doesn't flush to the client", func() {
   190  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   191  				w.Write([]byte("foobar"))
   192  			})
   193  
   194  			responseBuf := &bytes.Buffer{}
   195  			setRequest(encodeRequest(exampleGetRequest))
   196  			str.EXPECT().Context().Return(reqContext)
   197  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   198  			str.EXPECT().CancelRead(gomock.Any())
   199  
   200  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   201  			Expect(serr.err).ToNot(HaveOccurred())
   202  			hfs := decodeHeader(responseBuf)
   203  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   204  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
   205  			// status, content-length, date, content-type
   206  			Expect(hfs).To(HaveLen(4))
   207  		})
   208  
   209  		It("not sets Content-Length when the handler flushes to the client", func() {
   210  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   211  				w.Write([]byte("foobar"))
   212  				// force flush
   213  				w.(http.Flusher).Flush()
   214  			})
   215  
   216  			responseBuf := &bytes.Buffer{}
   217  			setRequest(encodeRequest(exampleGetRequest))
   218  			str.EXPECT().Context().Return(reqContext)
   219  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   220  			str.EXPECT().CancelRead(gomock.Any())
   221  
   222  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   223  			Expect(serr.err).ToNot(HaveOccurred())
   224  			hfs := decodeHeader(responseBuf)
   225  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   226  			// status, date, content-type
   227  			Expect(hfs).To(HaveLen(3))
   228  		})
   229  
   230  		It("response to HEAD request should not have body", func() {
   231  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   232  				w.Write([]byte("foobar"))
   233  			})
   234  
   235  			headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
   236  			Expect(err).ToNot(HaveOccurred())
   237  			responseBuf := &bytes.Buffer{}
   238  			setRequest(encodeRequest(headRequest))
   239  			str.EXPECT().Context().Return(reqContext)
   240  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   241  			str.EXPECT().CancelRead(gomock.Any())
   242  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   243  			Expect(serr.err).ToNot(HaveOccurred())
   244  			hfs := decodeHeader(responseBuf)
   245  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   246  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   247  		})
   248  
   249  		It("response to HEAD request should also do content sniffing", func() {
   250  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   251  				w.Write([]byte("<html></html>"))
   252  			})
   253  
   254  			headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
   255  			Expect(err).ToNot(HaveOccurred())
   256  			responseBuf := &bytes.Buffer{}
   257  			setRequest(encodeRequest(headRequest))
   258  			str.EXPECT().Context().Return(reqContext)
   259  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   260  			str.EXPECT().CancelRead(gomock.Any())
   261  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   262  			Expect(serr.err).ToNot(HaveOccurred())
   263  			hfs := decodeHeader(responseBuf)
   264  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   265  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   266  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   267  		})
   268  
   269  		It("handles a aborting handler", func() {
   270  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   271  				panic(http.ErrAbortHandler)
   272  			})
   273  
   274  			responseBuf := &bytes.Buffer{}
   275  			setRequest(encodeRequest(exampleGetRequest))
   276  			str.EXPECT().Context().Return(reqContext)
   277  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   278  			str.EXPECT().CancelRead(gomock.Any())
   279  
   280  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   281  			Expect(serr.err).To(MatchError(errPanicked))
   282  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   283  		})
   284  
   285  		It("handles a panicking handler", func() {
   286  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   287  				panic("foobar")
   288  			})
   289  
   290  			responseBuf := &bytes.Buffer{}
   291  			setRequest(encodeRequest(exampleGetRequest))
   292  			str.EXPECT().Context().Return(reqContext)
   293  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   294  			str.EXPECT().CancelRead(gomock.Any())
   295  
   296  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   297  			Expect(serr.err).To(MatchError(errPanicked))
   298  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   299  		})
   300  
   301  		Context("hijacking bidirectional streams", func() {
   302  			var conn *mockquic.MockEarlyConnection
   303  			testDone := make(chan struct{})
   304  
   305  			BeforeEach(func() {
   306  				testDone = make(chan struct{})
   307  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   308  				controlStr := mockquic.NewMockStream(mockCtrl)
   309  				controlStr.EXPECT().Write(gomock.Any())
   310  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   311  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   312  				conn.EXPECT().LocalAddr().AnyTimes()
   313  			})
   314  
   315  			AfterEach(func() { testDone <- struct{}{} })
   316  
   317  			It("hijacks a bidirectional stream of unknown frame type", func() {
   318  				frameTypeChan := make(chan FrameType, 1)
   319  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   320  					Expect(e).ToNot(HaveOccurred())
   321  					frameTypeChan <- ft
   322  					return true, nil
   323  				}
   324  
   325  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   326  				unknownStr := mockquic.NewMockStream(mockCtrl)
   327  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   328  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   329  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   330  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   331  					<-testDone
   332  					return nil, errors.New("test done")
   333  				})
   334  				s.handleConn(conn)
   335  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   336  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   337  			})
   338  
   339  			It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
   340  				frameTypeChan := make(chan FrameType, 1)
   341  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   342  					Expect(e).ToNot(HaveOccurred())
   343  					frameTypeChan <- ft
   344  					return false, nil
   345  				}
   346  
   347  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   348  				unknownStr := mockquic.NewMockStream(mockCtrl)
   349  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   350  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   351  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   352  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   353  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   354  					<-testDone
   355  					return nil, errors.New("test done")
   356  				})
   357  				s.handleConn(conn)
   358  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   359  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   360  			})
   361  
   362  			It("cancels writing when hijacker returned error", func() {
   363  				frameTypeChan := make(chan FrameType, 1)
   364  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   365  					Expect(e).ToNot(HaveOccurred())
   366  					frameTypeChan <- ft
   367  					return false, errors.New("error in hijacker")
   368  				}
   369  
   370  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   371  				unknownStr := mockquic.NewMockStream(mockCtrl)
   372  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   373  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   374  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   375  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   376  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   377  					<-testDone
   378  					return nil, errors.New("test done")
   379  				})
   380  				s.handleConn(conn)
   381  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   382  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   383  			})
   384  
   385  			It("handles errors that occur when reading the stream type", func() {
   386  				testErr := errors.New("test error")
   387  				done := make(chan struct{})
   388  				unknownStr := mockquic.NewMockStream(mockCtrl)
   389  				s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
   390  					defer close(done)
   391  					Expect(ft).To(BeZero())
   392  					Expect(str).To(Equal(unknownStr))
   393  					Expect(err).To(MatchError(testErr))
   394  					return true, nil
   395  				}
   396  
   397  				unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   398  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   399  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   400  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   401  					<-testDone
   402  					return nil, errors.New("test done")
   403  				})
   404  				s.handleConn(conn)
   405  				Eventually(done).Should(BeClosed())
   406  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   407  			})
   408  		})
   409  
   410  		Context("hijacking unidirectional streams", func() {
   411  			var conn *mockquic.MockEarlyConnection
   412  			testDone := make(chan struct{})
   413  
   414  			BeforeEach(func() {
   415  				testDone = make(chan struct{})
   416  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   417  				controlStr := mockquic.NewMockStream(mockCtrl)
   418  				controlStr.EXPECT().Write(gomock.Any())
   419  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   420  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   421  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   422  				conn.EXPECT().LocalAddr().AnyTimes()
   423  			})
   424  
   425  			AfterEach(func() { testDone <- struct{}{} })
   426  
   427  			It("hijacks an unidirectional stream of unknown stream type", func() {
   428  				streamTypeChan := make(chan StreamType, 1)
   429  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   430  					Expect(err).ToNot(HaveOccurred())
   431  					streamTypeChan <- st
   432  					return true
   433  				}
   434  
   435  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   436  				unknownStr := mockquic.NewMockStream(mockCtrl)
   437  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   438  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   439  					return unknownStr, nil
   440  				})
   441  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   442  					<-testDone
   443  					return nil, errors.New("test done")
   444  				})
   445  				s.handleConn(conn)
   446  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   447  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   448  			})
   449  
   450  			It("handles errors that occur when reading the stream type", func() {
   451  				testErr := errors.New("test error")
   452  				done := make(chan struct{})
   453  				unknownStr := mockquic.NewMockStream(mockCtrl)
   454  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   455  					defer close(done)
   456  					Expect(st).To(BeZero())
   457  					Expect(str).To(Equal(unknownStr))
   458  					Expect(err).To(MatchError(testErr))
   459  					return true
   460  				}
   461  
   462  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
   463  				conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   464  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   465  					<-testDone
   466  					return nil, errors.New("test done")
   467  				})
   468  				s.handleConn(conn)
   469  				Eventually(done).Should(BeClosed())
   470  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   471  			})
   472  
   473  			It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   474  				streamTypeChan := make(chan StreamType, 1)
   475  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   476  					Expect(err).ToNot(HaveOccurred())
   477  					streamTypeChan <- st
   478  					return false
   479  				}
   480  
   481  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   482  				unknownStr := mockquic.NewMockStream(mockCtrl)
   483  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   484  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   485  
   486  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   487  					return unknownStr, nil
   488  				})
   489  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   490  					<-testDone
   491  					return nil, errors.New("test done")
   492  				})
   493  				s.handleConn(conn)
   494  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   495  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   496  			})
   497  		})
   498  
   499  		Context("control stream handling", func() {
   500  			var conn *mockquic.MockEarlyConnection
   501  			testDone := make(chan struct{}, 1)
   502  
   503  			BeforeEach(func() {
   504  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   505  				controlStr := mockquic.NewMockStream(mockCtrl)
   506  				controlStr.EXPECT().Write(gomock.Any())
   507  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   508  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   509  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   510  				conn.EXPECT().LocalAddr().AnyTimes()
   511  			})
   512  
   513  			AfterEach(func() { testDone <- struct{}{} })
   514  
   515  			It("parses the SETTINGS frame", func() {
   516  				b := quicvarint.Append(nil, streamTypeControlStream)
   517  				b = (&settingsFrame{}).Append(b)
   518  				controlStr := mockquic.NewMockStream(mockCtrl)
   519  				r := bytes.NewReader(b)
   520  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   521  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   522  					return controlStr, nil
   523  				})
   524  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   525  					<-testDone
   526  					return nil, errors.New("test done")
   527  				})
   528  				s.handleConn(conn)
   529  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   530  			})
   531  
   532  			It("rejects duplicate control streams", func() {
   533  				b := quicvarint.Append(nil, streamTypeControlStream)
   534  				b = (&settingsFrame{}).Append(b)
   535  				r1 := bytes.NewReader(b)
   536  				controlStr1 := mockquic.NewMockStream(mockCtrl)
   537  				controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes()
   538  				r2 := bytes.NewReader(b)
   539  				controlStr2 := mockquic.NewMockStream(mockCtrl)
   540  				controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes()
   541  				done := make(chan struct{})
   542  				conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error {
   543  					close(done)
   544  					return nil
   545  				})
   546  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   547  					return controlStr1, nil
   548  				})
   549  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   550  					return controlStr2, nil
   551  				})
   552  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   553  					<-done
   554  					return nil, errors.New("test done")
   555  				})
   556  				s.handleConn(conn)
   557  				Eventually(done).Should(BeClosed())
   558  			})
   559  
   560  			for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   561  				streamType := t
   562  				name := "encoder"
   563  				if streamType == streamTypeQPACKDecoderStream {
   564  					name = "decoder"
   565  				}
   566  
   567  				It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   568  					buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   569  					str := mockquic.NewMockStream(mockCtrl)
   570  					str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   571  
   572  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   573  						return str, nil
   574  					})
   575  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   576  						<-testDone
   577  						return nil, errors.New("test done")
   578  					})
   579  					s.handleConn(conn)
   580  					time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   581  				})
   582  			}
   583  
   584  			It("reset streams other than the control stream and the QPACK streams", func() {
   585  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337))
   586  				str := mockquic.NewMockStream(mockCtrl)
   587  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   588  				done := make(chan struct{})
   589  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
   590  
   591  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   592  					return str, nil
   593  				})
   594  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   595  					<-testDone
   596  					return nil, errors.New("test done")
   597  				})
   598  				s.handleConn(conn)
   599  				Eventually(done).Should(BeClosed())
   600  			})
   601  
   602  			It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   603  				b := quicvarint.Append(nil, streamTypeControlStream)
   604  				b = (&dataFrame{}).Append(b)
   605  				controlStr := mockquic.NewMockStream(mockCtrl)
   606  				r := bytes.NewReader(b)
   607  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   608  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   609  					return controlStr, nil
   610  				})
   611  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   612  					<-testDone
   613  					return nil, errors.New("test done")
   614  				})
   615  				done := make(chan struct{})
   616  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   617  					close(done)
   618  					return nil
   619  				})
   620  				s.handleConn(conn)
   621  				Eventually(done).Should(BeClosed())
   622  			})
   623  
   624  			It("errors when parsing the frame on the control stream fails", func() {
   625  				b := quicvarint.Append(nil, streamTypeControlStream)
   626  				b = (&settingsFrame{}).Append(b)
   627  				r := bytes.NewReader(b[:len(b)-1])
   628  				controlStr := mockquic.NewMockStream(mockCtrl)
   629  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   630  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   631  					return controlStr, nil
   632  				})
   633  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   634  					<-testDone
   635  					return nil, errors.New("test done")
   636  				})
   637  				done := make(chan struct{})
   638  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   639  					close(done)
   640  					return nil
   641  				})
   642  				s.handleConn(conn)
   643  				Eventually(done).Should(BeClosed())
   644  			})
   645  
   646  			It("errors when the client opens a push stream", func() {
   647  				b := quicvarint.Append(nil, streamTypePushStream)
   648  				b = (&dataFrame{}).Append(b)
   649  				r := bytes.NewReader(b)
   650  				controlStr := mockquic.NewMockStream(mockCtrl)
   651  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   652  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   653  					return controlStr, nil
   654  				})
   655  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   656  					<-testDone
   657  					return nil, errors.New("test done")
   658  				})
   659  				done := make(chan struct{})
   660  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   661  					close(done)
   662  					return nil
   663  				})
   664  				s.handleConn(conn)
   665  				Eventually(done).Should(BeClosed())
   666  			})
   667  
   668  			It("errors when the client advertises datagram support (and we enabled support for it)", func() {
        				// Spec: the server has EnableDatagrams set and the client's SETTINGS
        				// frame has Datagram: true, but the mocked ConnectionState reports
        				// SupportsDatagrams: false. The expected outcome is
        				// CloseWithError(ErrCodeSettingsError, "missing QUIC Datagram support").
   669  				s.EnableDatagrams = true
        				// Stream bytes: the control stream type followed by the SETTINGS frame.
   670  				b := quicvarint.Append(nil, streamTypeControlStream)
   671  				b = (&settingsFrame{Datagram: true}).Append(b)
   672  				r := bytes.NewReader(b)
   673  				controlStr := mockquic.NewMockStream(mockCtrl)
   674  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
        				// This accept hands out the prepared control stream ...
   675  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   676  					return controlStr, nil
   677  				})
        				// ... and this one parks until testDone is signalled, then fails.
   678  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   679  					<-testDone
   680  					return nil, errors.New("test done")
   681  				})
   682  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
        				// done is closed by the expected CloseWithError call.
   683  				done := make(chan struct{})
   684  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
   685  					close(done)
   686  					return nil
   687  				})
   688  				s.handleConn(conn)
   689  				Eventually(done).Should(BeClosed())
   690  			})
   691  		})
   692  
   693  		Context("stream- and connection-level errors", func() {
   694  			var conn *mockquic.MockEarlyConnection
        			// testDone releases the AcceptUniStream mock installed in BeforeEach;
        			// it is replaced with a fresh channel before every spec.
   695  			testDone := make(chan struct{})
   696  
   697  			BeforeEach(func() {
   698  				testDone = make(chan struct{})
   699  				addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   700  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
        				// The server is expected to open one unidirectional stream and
        				// write to it once.
   701  				controlStr := mockquic.NewMockStream(mockCtrl)
   702  				controlStr.EXPECT().Write(gomock.Any())
   703  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
        				// The accept for incoming unidirectional streams parks until
        				// testDone is signalled (see AfterEach).
   704  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   705  					<-testDone
   706  					return nil, errors.New("test done")
   707  				})
        				// One request stream (str, declared outside this Context) is
        				// accepted; the following AcceptStream call returns an error.
   708  				conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
   709  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   710  				conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   711  				conn.EXPECT().LocalAddr().AnyTimes()
   712  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   713  			})
   714  
        			// Unblocks the parked AcceptUniStream mock. testDone is unbuffered,
        			// so this send waits until a receiver is ready.
   715  			AfterEach(func() { testDone <- struct{}{} })
   716  
   717  			It("cancels reading when client sends a body in GET request", func() {
        				// Spec: a GET request followed by a DATA frame. The handler never
        				// reads the body; the stream is expected to see
        				// CancelRead(ErrCodeNoError) and Close, and the response carries
        				// status 200.
   718  				var handlerCalled bool
   719  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   720  					handlerCalled = true
   721  				})
   722  
   723  				requestData := encodeRequest(exampleGetRequest)
   724  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   725  				b = append(b, []byte("foobar")...)
        				// responseBuf captures everything the server writes to the stream.
   726  				responseBuf := &bytes.Buffer{}
   727  				setRequest(append(requestData, b...))
   728  				done := make(chan struct{})
   729  				str.EXPECT().Context().Return(reqContext)
   730  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   731  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
        				// Closing the stream marks the end of request handling for this spec.
   732  				str.EXPECT().Close().Do(func() error { close(done); return nil })
   733  
   734  				s.handleConn(conn)
   735  				Eventually(done).Should(BeClosed())
   736  				hfs := decodeHeader(responseBuf)
   737  				Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   738  				Expect(handlerCalled).To(BeTrue())
   739  			})
   740  
   741  			It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
        				// Spec: the handler takes over the stream through
        				// HTTPStreamer.HTTPStream() and writes to str directly. The only
        				// expectations registered on str are Context and a single
        				// Write("foobar"); no Close or CancelRead expectation is set up.
   742  				handlerCalled := make(chan struct{})
   743  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   744  					defer close(handlerCalled)
   745  					r.Body.(HTTPStreamer).HTTPStream()
   746  					str.Write([]byte("foobar"))
   747  				})
   748  
   749  				requestData := encodeRequest(exampleGetRequest)
   750  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   751  				b = append(b, []byte("foobar")...)
   752  				setRequest(append(requestData, b...))
   753  				str.EXPECT().Context().Return(reqContext)
   754  				str.EXPECT().Write([]byte("foobar")).Return(6, nil)
   755  
   756  				s.handleConn(conn)
   757  				Eventually(handlerCalled).Should(BeClosed())
   758  			})
   759  
   760  			It("errors when the client sends a too large header frame", func() {
        				// Spec: with MaxHeaderBytes lowered to 20, the encoded example GET
        				// request is expected to be rejected with
        				// CancelWrite(ErrCodeFrameError) and the handler must not run.
   761  				s.MaxHeaderBytes = 20
   762  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   763  					Fail("Handler should not be called.")
   764  				})
   765  
   766  				requestData := encodeRequest(exampleGetRequest)
   767  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   768  				b = append(b, []byte("foobar")...)
   769  				responseBuf := &bytes.Buffer{}
   770  				setRequest(append(requestData, b...))
        				// done is closed by the expected CancelWrite call.
   771  				done := make(chan struct{})
   772  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   773  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   774  
   775  				s.handleConn(conn)
   776  				Eventually(done).Should(BeClosed())
   777  			})
   778  
   779  			It("handles a request for which the client immediately resets the stream", func() {
        				// Spec: the very first Read on the request stream fails. The server
        				// is expected to answer with CancelWrite(ErrCodeRequestIncomplete)
        				// and must never invoke the handler.
   780  				handlerCalled := make(chan struct{})
   781  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   782  					close(handlerCalled)
   783  				})
   784  
   785  				testErr := errors.New("stream reset")
        				// done is closed by the expected CancelWrite call.
   786  				done := make(chan struct{})
   787  				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
   788  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
   789  
   790  				s.handleConn(conn)
        				// Wait for the reset to be handled before checking that the
        				// handler stays untouched; previously done was never awaited.
        				Eventually(done).Should(BeClosed())
   791  				Consistently(handlerCalled).ShouldNot(BeClosed())
   792  			})
   793  
   794  			It("closes the connection when the first frame is not a HEADERS frame", func() {
        				// Spec: the request stream starts with a DATA frame instead of a
        				// HEADERS frame; the connection is expected to be closed with
        				// ErrCodeFrameUnexpected.
        				// NOTE(review): handlerCalled is only closed by the handler; nothing
        				// in this spec waits on it or asserts that it stays open.
   795  				handlerCalled := make(chan struct{})
   796  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   797  					close(handlerCalled)
   798  				})
   799  
   800  				b := (&dataFrame{}).Append(nil)
   801  				setRequest(b)
        				// Accept and discard anything written to the stream.
   802  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   803  					return len(p), nil
   804  				}).AnyTimes()
   805  
        				// done is closed by the expected CloseWithError call.
   806  				done := make(chan struct{})
   807  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   808  					close(done)
   809  					return nil
   810  				})
   811  				s.handleConn(conn)
   812  				Eventually(done).Should(BeClosed())
   813  			})
   814  
   815  			It("rejects a request whose header is larger than http.DefaultMaxHeaderBytes", func() {
        				// Spec: a GET request with an oversized URL. The stream is expected
        				// to be reset with CancelWrite(ErrCodeFrameError).
        				// (The description used to duplicate the "first frame is not a
        				// HEADERS frame" spec, which is not what is tested here.)
        				// NOTE(review): handlerCalled is only closed by the handler; nothing
        				// in this spec waits on it or asserts that it stays open.
   816  				handlerCalled := make(chan struct{})
   817  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   818  					close(handlerCalled)
   819  				})
   820  
   821  				// use 2*DefaultMaxHeaderBytes here. qpack will compress the request,
   822  				// but the request will still end up larger than DefaultMaxHeaderBytes.
   823  				url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
   824  				req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
   825  				Expect(err).ToNot(HaveOccurred())
   826  				setRequest(encodeRequest(req))
   827  				// str.EXPECT().Context().Return(reqContext)
        				// Accept and discard anything written to the stream.
   828  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   829  					return len(p), nil
   830  				}).AnyTimes()
        				// done is closed by the expected CancelWrite call.
   831  				done := make(chan struct{})
   832  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   833  
   834  				s.handleConn(conn)
   835  				Eventually(done).Should(BeClosed())
   836  			})
   837  		})
   838  
   839  		It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
        			// Spec: the handler overwrites r.Body with a zero-value
        			// Reader/Closer pair and never reads the original body. The stream
        			// is still expected to see CancelRead(ErrCodeNoError), and
        			// handleRequest must return without an error.
   840  			handlerCalled := make(chan struct{})
   841  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   842  				r.Body = struct {
   843  					io.Reader
   844  					io.Closer
   845  				}{}
   846  				close(handlerCalled)
   847  			})
   848  
   849  			setRequest(encodeRequest(examplePostRequest))
   850  			str.EXPECT().Context().Return(reqContext)
        			// Accept and discard anything written to the stream.
   851  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   852  				return len(p), nil
   853  			}).AnyTimes()
   854  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   855  
        			// handleRequest is called directly instead of going through handleConn.
   856  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   857  			Expect(serr.err).ToNot(HaveOccurred())
   858  			Eventually(handlerCalled).Should(BeClosed())
   859  		})
   860  
   861  		It("cancels the request context when the stream is closed", func() {
        			// Spec: the stream's Context() returns an already-cancelled context;
        			// the handler must observe the request context as done with
        			// context.Canceled.
   862  			handlerCalled := make(chan struct{})
   863  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   864  				defer GinkgoRecover()
   865  				Expect(r.Context().Done()).To(BeClosed())
   866  				Expect(r.Context().Err()).To(MatchError(context.Canceled))
   867  				close(handlerCalled)
   868  			})
   869  			setRequest(encodeRequest(examplePostRequest))
   870  
        			// Local reqContext, cancelled before the request is handled.
   871  			reqContext, cancel := context.WithCancel(context.Background())
   872  			cancel()
   873  			str.EXPECT().Context().Return(reqContext)
        			// Accept and discard anything written to the stream.
   874  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   875  				return len(p), nil
   876  			}).AnyTimes()
   877  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   878  
        			// handleRequest is called directly instead of going through handleConn.
   879  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   880  			Expect(serr.err).ToNot(HaveOccurred())
   881  			Eventually(handlerCalled).Should(BeClosed())
   882  		})
   883  	})
   884  
   885  	Context("setting http headers", func() {
        		// Specs for Server.SetQuicHeaders: each one registers mock listeners,
        		// checks the resulting Alt-Svc header, removes the listeners again and
        		// checks that SetQuicHeaders then returns ErrNoAltSvcPort.
   886  		BeforeEach(func() {
   887  			s.QuicConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}}
   888  		})
   889  
   890  		var ln1 QUICEarlyListener
   891  		var ln2 QUICEarlyListener
        		// expected is the header produced for a single listener on port 443.
   892  		expected := http.Header{
   893  			"Alt-Svc": {`h3=":443"; ma=2592000`},
   894  		}
   895  
        		// addListener creates a mock listener reporting addr, expects exactly
        		// one Addr() call on the underlying mock, stores the listener in *ln
        		// and registers it on the server.
   896  		addListener := func(addr string, ln *QUICEarlyListener) {
   897  			mln := newMockAddrListener(addr)
   898  			mln.EXPECT().Addr()
   899  			*ln = mln
   900  			s.addListener(ln)
   901  		}
   902  
        		// removeListener unregisters a listener previously passed to addListener.
   903  		removeListener := func(ln *QUICEarlyListener) {
   904  			s.removeListener(ln)
   905  		}
   906  
        		// checkSetHeaders asserts that SetQuicHeaders succeeds on an empty
        		// header and that the resulting header satisfies the matcher.
   907  		checkSetHeaders := func(expected gmtypes.GomegaMatcher) {
   908  			hdr := http.Header{}
   909  			Expect(s.SetQuicHeaders(hdr)).To(Succeed())
   910  			Expect(hdr).To(expected)
   911  		}
   912  
        		// checkSetHeaderError asserts that SetQuicHeaders returns ErrNoAltSvcPort.
   913  		checkSetHeaderError := func() {
   914  			hdr := http.Header{}
   915  			Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
   916  		}
   917  
   918  		It("sets proper headers with numeric port", func() {
   919  			addListener(":443", &ln1)
   920  			checkSetHeaders(Equal(expected))
   921  			removeListener(&ln1)
   922  			checkSetHeaderError()
   923  		})
   924  
   925  		It("sets proper headers with full addr", func() {
   926  			addListener("127.0.0.1:443", &ln1)
   927  			checkSetHeaders(Equal(expected))
   928  			removeListener(&ln1)
   929  			checkSetHeaderError()
   930  		})
   931  
   932  		It("sets proper headers with string port", func() {
        			// The service name "https" is expected to appear as port 443.
   933  			addListener(":https", &ln1)
   934  			checkSetHeaders(Equal(expected))
   935  			removeListener(&ln1)
   936  			checkSetHeaderError()
   937  		})
   938  
   939  		It("works multiple times", func() {
   940  			addListener(":https", &ln1)
   941  			checkSetHeaders(Equal(expected))
   942  			checkSetHeaders(Equal(expected))
   943  			removeListener(&ln1)
   944  			checkSetHeaderError()
   945  		})
   946  
   947  		It("works if the quic.Config sets QUIC versions", func() {
        			// Two configured versions still yield a single h3 entry.
   948  			s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version2}
   949  			addListener(":443", &ln1)
   950  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   951  			removeListener(&ln1)
   952  			checkSetHeaderError()
   953  		})
   954  
   955  		It("uses s.Port if set to a non-zero value", func() {
        			// The listener reports :443, but the advertised port is s.Port.
   956  			s.Port = 8443
   957  			addListener(":443", &ln1)
   958  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
   959  			removeListener(&ln1)
   960  			checkSetHeaderError()
   961  		})
   962  
   963  		It("uses s.Addr if listeners don't have ports available", func() {
   964  			s.Addr = ":443"
        			// noPortListener.Addr() returns a unix socket address, which has no port.
   965  			mln := &noPortListener{newMockAddrListener("")}
   966  			mln.EXPECT().Addr()
   967  			ln1 = mln
   968  			s.addListener(&ln1)
   969  			checkSetHeaders(Equal(expected))
   970  			s.removeListener(&ln1)
   971  			checkSetHeaderError()
   972  		})
   973  
   974  		It("properly announces multiple listeners", func() {
   975  			addListener(":443", &ln1)
   976  			addListener(":8443", &ln2)
        			// Either ordering of the two entries is accepted.
   977  			checkSetHeaders(Or(
   978  				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
   979  				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
   980  			))
   981  			removeListener(&ln1)
   982  			removeListener(&ln2)
   983  			checkSetHeaderError()
   984  		})
   985  
   986  		It("doesn't duplicate Alt-Svc values", func() {
        			// The same version listed twice must not produce two entries.
   987  			s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version1}
   988  			addListener(":443", &ln1)
   989  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   990  			removeListener(&ln1)
   991  			checkSetHeaderError()
   992  		})
   993  	})
   994  
   995  	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
        		// A zero-value Server has no TLSConfig; ListenAndServe is expected to
        		// return errServerWithoutTLSConfig.
   996  		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
   997  	})
   998  
   999  	It("should nop-Close() when s.server is nil", func() {
        		// Close on a zero-value Server is expected to succeed.
  1000  		Expect((&Server{}).Close()).To(Succeed())
  1001  	})
  1002  
  1003  	It("errors when ListenAndServeTLS is called after Close", func() {
        		// Once Close has been called, ListenAndServeTLS is expected to return
        		// http.ErrServerClosed.
  1004  		serv := &Server{}
  1005  		Expect(serv.Close()).To(Succeed())
  1006  		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
  1007  	})
  1008  
  1009  	It("handles concurrent Serve and Close", func() {
        		// Spec: Serve runs in a goroutine while Close is called from the spec
        		// goroutine; Serve is expected to return in either interleaving.
  1010  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
  1011  		Expect(err).ToNot(HaveOccurred())
  1012  		c, err := net.ListenUDP("udp", addr)
  1013  		Expect(err).ToNot(HaveOccurred())
        		// The socket was opened by this spec, so release it here instead of
        		// relying on the server to do so. A second Close is harmless: its
        		// error is ignored.
        		defer c.Close()
  1014  		done := make(chan struct{})
  1015  		go func() {
  1016  			defer GinkgoRecover()
  1017  			defer close(done)
  1018  			s.Serve(c)
  1019  		}()
        		// Yield so the Serve goroutine gets a chance to run before Close.
  1020  		runtime.Gosched()
  1021  		s.Close()
  1022  		Eventually(done).Should(BeClosed())
  1023  	})
  1024  
  1025  	Context("ConfigureTLSConfig", func() {
        		// Most specs here start a real QUIC listener on localhost with a
        		// tls.Config passed through ConfigureTLSConfig, dial it with
        		// NextProtoH3 and check the negotiated ALPN protocol.
  1026  		It("advertises v1 by default", func() {
  1027  			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
  1028  			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}})
  1029  			Expect(err).ToNot(HaveOccurred())
  1030  			defer ln.Close()
  1031  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1032  			Expect(err).ToNot(HaveOccurred())
  1033  			defer c.CloseWithError(0, "")
  1034  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1035  		})
  1036  
  1037  		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
        			// quicListenAddr is stubbed to capture the tls.Config it receives and
        			// to fail. The assertion only checks that a non-nil config arrived.
  1038  			var receivedConf *tls.Config
  1039  			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
  1040  				receivedConf = tlsConf
  1041  				return nil, errors.New("listen err")
  1042  			}
  1043  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1044  			Expect(receivedConf).ToNot(BeNil())
  1045  		})
  1046  
  1047  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
        			// The callback returns a fresh config advertising "foo" and "bar";
        			// the handshake is still expected to negotiate NextProtoH3.
  1048  			tlsConf := &tls.Config{
  1049  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  1050  					c := testdata.GetTLSConfig()
  1051  					c.NextProtos = []string{"foo", "bar"}
  1052  					return c, nil
  1053  				},
  1054  			}
  1055  
  1056  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1057  			Expect(err).ToNot(HaveOccurred())
  1058  			defer ln.Close()
  1059  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1060  			Expect(err).ToNot(HaveOccurred())
  1061  			defer c.CloseWithError(0, "")
  1062  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1063  		})
  1064  
  1065  		It("works if GetConfigForClient returns a nil tls.Config", func() {
        			// The callback returns (nil, nil); the handshake must still succeed.
  1066  			tlsConf := testdata.GetTLSConfig()
  1067  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
  1068  
  1069  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1070  			Expect(err).ToNot(HaveOccurred())
  1071  			defer ln.Close()
  1072  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1073  			Expect(err).ToNot(HaveOccurred())
  1074  			defer c.CloseWithError(0, "")
  1075  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1076  		})
  1077  
  1078  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
        			// The callback returns the same shared config on every call; after
        			// the handshake its NextProtos must still be {"foo", "bar"}.
  1079  			tlsClientConf := testdata.GetTLSConfig()
  1080  			tlsClientConf.NextProtos = []string{"foo", "bar"}
  1081  			tlsConf := &tls.Config{
  1082  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  1083  					return tlsClientConf, nil
  1084  				},
  1085  			}
  1086  
  1087  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1088  			Expect(err).ToNot(HaveOccurred())
  1089  			defer ln.Close()
  1090  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1091  			Expect(err).ToNot(HaveOccurred())
  1092  			defer c.CloseWithError(0, "")
  1093  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1094  			// check that the original config was not modified
  1095  			Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
  1096  		})
  1097  	})
  1098  
  1099  	Context("Serve", func() {
        		// Specs for Server.Serve with quicListen replaced by a stub that
        		// returns mock listeners. The original quicListen is restored after
        		// every spec.
  1100  		origQuicListen := quicListen
  1101  
  1102  		AfterEach(func() {
  1103  			quicListen = origQuicListen
  1104  		})
  1105  
  1106  		It("serves a packet conn", func() {
  1107  			ln := newMockAddrListener(":443")
  1108  			conn := &net.UDPConn{}
        			// The stub checks that Serve passes the given packet conn through.
  1109  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1110  				Expect(c).To(Equal(conn))
  1111  				return ln, nil
  1112  			}
  1113  
        			// Local server; shadows the suite-level s.
  1114  			s := &Server{
  1115  				TLSConfig: &tls.Config{},
  1116  			}
  1117  
        			// Accept blocks until the listener's Close closes stopAccept.
  1118  			stopAccept := make(chan struct{})
  1119  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1120  				<-stopAccept
  1121  				return nil, errors.New("closed")
  1122  			})
  1123  			ln.EXPECT().Addr() // generate alt-svc headers
  1124  			done := make(chan struct{})
  1125  			go func() {
  1126  				defer GinkgoRecover()
  1127  				defer close(done)
  1128  				s.Serve(conn)
  1129  			}()
  1130  
        			// Serve must keep running until the server is closed.
  1131  			Consistently(done).ShouldNot(BeClosed())
  1132  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1133  			Expect(s.Close()).To(Succeed())
  1134  			Eventually(done).Should(BeClosed())
  1135  		})
  1136  
  1137  		It("serves two packet conns", func() {
  1138  			ln1 := newMockAddrListener(":443")
  1139  			ln2 := newMockAddrListener(":8443")
        			// The buffered channel hands one mock listener to each quicListen call.
  1140  			lns := make(chan QUICEarlyListener, 2)
  1141  			lns <- ln1
  1142  			lns <- ln2
  1143  			conn1 := &net.UDPConn{}
  1144  			conn2 := &net.UDPConn{}
  1145  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1146  				return <-lns, nil
  1147  			}
  1148  
        			// Local server; shadows the suite-level s.
  1149  			s := &Server{
  1150  				TLSConfig: &tls.Config{},
  1151  			}
  1152  
        			// Each listener's Accept blocks until that listener is closed.
  1153  			stopAccept1 := make(chan struct{})
  1154  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1155  				<-stopAccept1
  1156  				return nil, errors.New("closed")
  1157  			})
  1158  			ln1.EXPECT().Addr() // generate alt-svc headers
  1159  			stopAccept2 := make(chan struct{})
  1160  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1161  				<-stopAccept2
  1162  				return nil, errors.New("closed")
  1163  			})
  1164  			ln2.EXPECT().Addr()
  1165  
  1166  			done1 := make(chan struct{})
  1167  			go func() {
  1168  				defer GinkgoRecover()
  1169  				defer close(done1)
  1170  				s.Serve(conn1)
  1171  			}()
  1172  			done2 := make(chan struct{})
  1173  			go func() {
  1174  				defer GinkgoRecover()
  1175  				defer close(done2)
  1176  				s.Serve(conn2)
  1177  			}()
  1178  
        			// Both Serve calls must keep running until the server is closed.
  1179  			Consistently(done1).ShouldNot(BeClosed())
  1180  			Expect(done2).ToNot(BeClosed())
        			// A single Close on the server is expected to close both listeners.
  1181  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1182  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1183  			Expect(s.Close()).To(Succeed())
  1184  			Eventually(done1).Should(BeClosed())
  1185  			Eventually(done2).Should(BeClosed())
  1186  		})
  1187  	})
  1188  
  1189  	Context("ServeListener", func() {
        		// Specs for Server.ServeListener. quicListen is stubbed only to record
        		// whether it gets called; the specs assert that it never is. The
        		// original quicListen is restored after every spec.
  1190  		origQuicListen := quicListen
  1191  
  1192  		AfterEach(func() {
  1193  			quicListen = origQuicListen
  1194  		})
  1195  
  1196  		It("serves a listener", func() {
        			// called is set to 1 if the quicListen stub is ever invoked.
  1197  			var called int32
  1198  			ln := newMockAddrListener(":443")
  1199  			quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1200  				atomic.StoreInt32(&called, 1)
  1201  				return ln, nil
  1202  			}
  1203  
        			// Local server; shadows the suite-level s.
  1204  			s := &Server{}
  1205  
        			// Accept blocks until the listener's Close closes stopAccept.
  1206  			stopAccept := make(chan struct{})
  1207  			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1208  				<-stopAccept
  1209  				return nil, errors.New("closed")
  1210  			})
  1211  			ln.EXPECT().Addr() // generate alt-svc headers
  1212  			done := make(chan struct{})
  1213  			go func() {
  1214  				defer GinkgoRecover()
  1215  				defer close(done)
  1216  				s.ServeListener(ln)
  1217  			}()
  1218  
        			// The supplied listener is used as-is: quicListen stays uncalled.
  1219  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1220  			Consistently(done).ShouldNot(BeClosed())
  1221  			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
  1222  			Expect(s.Close()).To(Succeed())
  1223  			Eventually(done).Should(BeClosed())
  1224  		})
  1225  
  1226  		It("serves two listeners", func() {
        			// called is set to 1 if the quicListen stub is ever invoked.
  1227  			var called int32
  1228  			ln1 := newMockAddrListener(":443")
  1229  			ln2 := newMockAddrListener(":8443")
  1230  			lns := make(chan QUICEarlyListener, 2)
  1231  			lns <- ln1
  1232  			lns <- ln2
  1233  			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1234  				atomic.StoreInt32(&called, 1)
  1235  				return <-lns, nil
  1236  			}
  1237  
        			// Local server; shadows the suite-level s.
  1238  			s := &Server{}
  1239  
        			// Each listener's Accept blocks until that listener is closed.
  1240  			stopAccept1 := make(chan struct{})
  1241  			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1242  				<-stopAccept1
  1243  				return nil, errors.New("closed")
  1244  			})
  1245  			ln1.EXPECT().Addr() // generate alt-svc headers
  1246  			stopAccept2 := make(chan struct{})
  1247  			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
  1248  				<-stopAccept2
  1249  				return nil, errors.New("closed")
  1250  			})
  1251  			ln2.EXPECT().Addr()
  1252  
  1253  			done1 := make(chan struct{})
  1254  			go func() {
  1255  				defer GinkgoRecover()
  1256  				defer close(done1)
  1257  				s.ServeListener(ln1)
  1258  			}()
  1259  			done2 := make(chan struct{})
  1260  			go func() {
  1261  				defer GinkgoRecover()
  1262  				defer close(done2)
  1263  				s.ServeListener(ln2)
  1264  			}()
  1265  
        			// The supplied listeners are used as-is: quicListen stays uncalled.
  1266  			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
  1267  			Consistently(done1).ShouldNot(BeClosed())
  1268  			Expect(done2).ToNot(BeClosed())
        			// A single Close on the server is expected to close both listeners.
  1269  			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
  1270  			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
  1271  			Expect(s.Close()).To(Succeed())
  1272  			Eventually(done1).Should(BeClosed())
  1273  			Eventually(done2).Should(BeClosed())
  1274  		})
  1275  	})
  1276  
  1277  	Context("ServeQUICConn", func() {
  1278  		It("serves a QUIC connection", func() {
        			// Spec: ServeQUICConn is driven with a mock connection whose
        			// AcceptStream fails right away with an ErrCodeNoError application
        			// error, so the call returns without serving any request.
        			// (A tls.Config that was built here but never used has been removed.)
  1279  			mux := http.NewServeMux()
  1280  			mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
  1281  				w.Write([]byte("foobar"))
  1282  			})
  1283  			s.Handler = mux
  1286  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
        			// The server is expected to open one unidirectional stream and
        			// write to it once.
  1287  			controlStr := mockquic.NewMockStream(mockCtrl)
  1288  			controlStr.EXPECT().Write(gomock.Any())
  1289  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
        			// AcceptUniStream may be called at most once; if it is, it parks
        			// until testDone is closed at the end of the spec.
  1290  			testDone := make(chan struct{})
  1291  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
  1292  				<-testDone
  1293  				return nil, errors.New("test done")
  1294  			}).MaxTimes(1)
  1295  			conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})
  1296  			s.ServeQUICConn(conn)
  1297  			close(testDone)
  1298  		})
  1299  	})
  1300  
  1301  	Context("ListenAndServe", func() {
        		// Specs for Server.ListenAndServe with s.Addr set to an ephemeral
        		// localhost port; the server is closed after every spec.
  1302  		BeforeEach(func() {
  1303  			s.Addr = "localhost:0"
  1304  		})
  1305  
  1306  		AfterEach(func() {
  1307  			Expect(s.Close()).To(Succeed())
  1308  		})
  1309  
  1310  		It("uses the quic.Config to start the QUIC server", func() {
  1311  			conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
        			// quicListenAddr is stubbed to capture the config it receives and
        			// to fail, so no real listener is started.
  1312  			var receivedConf *quic.Config
  1313  			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1314  				receivedConf = config
  1315  				return nil, errors.New("listen err")
  1316  			}
  1317  			s.QuicConfig = conf
  1318  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1319  			Expect(receivedConf).To(Equal(conf))
  1320  		})
  1321  	})
  1322  
  1323  	It("closes gracefully", func() {
        		// CloseGracefully with a zero timeout is expected to succeed.
  1324  		Expect(s.CloseGracefully(0)).To(Succeed())
  1325  	})
  1326  
  1327  	It("errors when listening fails", func() {
        		// quicListenAddr is stubbed to fail; ListenAndServeQUIC is expected to
        		// surface that same error.
  1328  		testErr := errors.New("listen error")
  1329  		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1330  			return nil, testErr
  1331  		}
  1332  		fullpem, privkey := testdata.GetCertificatePaths()
  1333  		Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
  1334  	})
  1335  
  1336  	It("supports H3_DATAGRAM", func() {
        		// With s.EnableDatagrams set, the quic.Config handed to quicListenAddr
        		// is expected to have EnableDatagrams set as well.
  1337  		s.EnableDatagrams = true
        		// quicListenAddr is stubbed to capture the config it receives and to fail.
  1338  		var receivedConf *quic.Config
  1339  		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1340  			receivedConf = config
  1341  			return nil, errors.New("listen err")
  1342  		}
  1343  		Expect(s.ListenAndServe()).To(HaveOccurred())
  1344  		Expect(receivedConf.EnableDatagrams).To(BeTrue())
  1345  	})
  1346  })