github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/http3/server_test.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"runtime"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/TugasAkhir-QUIC/quic-go"
    17  	mockquic "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/quic"
    18  	"github.com/TugasAkhir-QUIC/quic-go/internal/protocol"
    19  	"github.com/TugasAkhir-QUIC/quic-go/internal/testdata"
    20  	"github.com/TugasAkhir-QUIC/quic-go/internal/utils"
    21  	"github.com/TugasAkhir-QUIC/quic-go/quicvarint"
    22  
    23  	"github.com/quic-go/qpack"
    24  	"go.uber.org/mock/gomock"
    25  
    26  	. "github.com/onsi/ginkgo/v2"
    27  	. "github.com/onsi/gomega"
    28  	gmtypes "github.com/onsi/gomega/types"
    29  )
    30  
    31  type mockAddr struct{ addr string }
    32  
    33  func (ma *mockAddr) Network() string { return "udp" }
    34  func (ma *mockAddr) String() string  { return ma.addr }
    35  
    36  type mockAddrListener struct {
    37  	*MockQUICEarlyListener
    38  	addr *mockAddr
    39  }
    40  
    41  func (m *mockAddrListener) Addr() net.Addr {
    42  	_ = m.MockQUICEarlyListener.Addr()
    43  	return m.addr
    44  }
    45  
    46  func newMockAddrListener(addr string) *mockAddrListener {
    47  	return &mockAddrListener{
    48  		MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl),
    49  		addr:                  &mockAddr{addr: addr},
    50  	}
    51  }
    52  
    53  type noPortListener struct {
    54  	*mockAddrListener
    55  }
    56  
    57  func (m *noPortListener) Addr() net.Addr {
    58  	_ = m.mockAddrListener.Addr()
    59  	return &net.UnixAddr{
    60  		Net:  "unix",
    61  		Name: "/tmp/quic.sock",
    62  	}
    63  }
    64  
    65  var _ = Describe("Server", func() {
    66  	var (
    67  		s                  *Server
    68  		origQuicListenAddr = quicListenAddr
    69  	)
    70  	type testConnContextKey string
    71  
    72  	BeforeEach(func() {
    73  		s = &Server{
    74  			TLSConfig: testdata.GetTLSConfig(),
    75  			logger:    utils.DefaultLogger,
    76  			ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
    77  				return context.WithValue(ctx, testConnContextKey("test"), c)
    78  			},
    79  		}
    80  		origQuicListenAddr = quicListenAddr
    81  	})
    82  
    83  	AfterEach(func() {
    84  		quicListenAddr = origQuicListenAddr
    85  	})
    86  
    87  	Context("handling requests", func() {
    88  		var (
    89  			qpackDecoder       *qpack.Decoder
    90  			str                *mockquic.MockStream
    91  			conn               *mockquic.MockEarlyConnection
    92  			exampleGetRequest  *http.Request
    93  			examplePostRequest *http.Request
    94  		)
    95  		reqContext := context.Background()
    96  
    97  		decodeHeader := func(str io.Reader) map[string][]string {
    98  			fields := make(map[string][]string)
    99  			decoder := qpack.NewDecoder(nil)
   100  
   101  			frame, err := parseNextFrame(str, nil)
   102  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   103  			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
   104  			headersFrame := frame.(*headersFrame)
   105  			data := make([]byte, headersFrame.Length)
   106  			_, err = io.ReadFull(str, data)
   107  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   108  			hfs, err := decoder.DecodeFull(data)
   109  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   110  			for _, p := range hfs {
   111  				fields[p.Name] = append(fields[p.Name], p.Value)
   112  			}
   113  			return fields
   114  		}
   115  
   116  		encodeRequest := func(req *http.Request) []byte {
   117  			buf := &bytes.Buffer{}
   118  			str := mockquic.NewMockStream(mockCtrl)
   119  			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
   120  			rw := newRequestWriter(utils.DefaultLogger)
   121  			Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed())
   122  			return buf.Bytes()
   123  		}
   124  
   125  		setRequest := func(data []byte) {
   126  			buf := bytes.NewBuffer(data)
   127  			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   128  				if buf.Len() == 0 {
   129  					return 0, io.EOF
   130  				}
   131  				return buf.Read(p)
   132  			}).AnyTimes()
   133  		}
   134  
   135  		BeforeEach(func() {
   136  			var err error
   137  			exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
   138  			Expect(err).ToNot(HaveOccurred())
   139  			examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
   140  			Expect(err).ToNot(HaveOccurred())
   141  
   142  			qpackDecoder = qpack.NewDecoder(nil)
   143  			str = mockquic.NewMockStream(mockCtrl)
   144  			conn = mockquic.NewMockEarlyConnection(mockCtrl)
   145  			addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   146  			conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   147  			conn.EXPECT().LocalAddr().AnyTimes()
   148  			conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   149  		})
   150  
   151  		It("calls the HTTP handler function", func() {
   152  			requestChan := make(chan *http.Request, 1)
   153  			s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
   154  				requestChan <- r
   155  			})
   156  
   157  			setRequest(encodeRequest(exampleGetRequest))
   158  			str.EXPECT().Context().Return(reqContext)
   159  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   160  				return len(p), nil
   161  			}).AnyTimes()
   162  			str.EXPECT().CancelRead(gomock.Any())
   163  
   164  			Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
   165  			var req *http.Request
   166  			Eventually(requestChan).Should(Receive(&req))
   167  			Expect(req.Host).To(Equal("www.example.com"))
   168  			Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
   169  			Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
   170  			Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil))
   171  		})
   172  
   173  		It("returns 200 with an empty handler", func() {
   174  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   175  
   176  			responseBuf := &bytes.Buffer{}
   177  			setRequest(encodeRequest(exampleGetRequest))
   178  			str.EXPECT().Context().Return(reqContext)
   179  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   180  			str.EXPECT().CancelRead(gomock.Any())
   181  
   182  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   183  			Expect(serr.err).ToNot(HaveOccurred())
   184  			hfs := decodeHeader(responseBuf)
   185  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   186  		})
   187  
   188  		It("sets Content-Length when the handler doesn't flush to the client", func() {
   189  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   190  				w.Write([]byte("foobar"))
   191  			})
   192  
   193  			responseBuf := &bytes.Buffer{}
   194  			setRequest(encodeRequest(exampleGetRequest))
   195  			str.EXPECT().Context().Return(reqContext)
   196  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   197  			str.EXPECT().CancelRead(gomock.Any())
   198  
   199  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   200  			Expect(serr.err).ToNot(HaveOccurred())
   201  			hfs := decodeHeader(responseBuf)
   202  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   203  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
   204  			// status, content-length, date, content-type
   205  			Expect(hfs).To(HaveLen(4))
   206  		})
   207  
   208  		It("not sets Content-Length when the handler flushes to the client", func() {
   209  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   210  				w.Write([]byte("foobar"))
   211  				// force flush
   212  				w.(http.Flusher).Flush()
   213  			})
   214  
   215  			responseBuf := &bytes.Buffer{}
   216  			setRequest(encodeRequest(exampleGetRequest))
   217  			str.EXPECT().Context().Return(reqContext)
   218  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   219  			str.EXPECT().CancelRead(gomock.Any())
   220  
   221  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   222  			Expect(serr.err).ToNot(HaveOccurred())
   223  			hfs := decodeHeader(responseBuf)
   224  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   225  			// status, date, content-type
   226  			Expect(hfs).To(HaveLen(3))
   227  		})
   228  
   229  		It("response to HEAD request should not have body", func() {
   230  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   231  				w.Write([]byte("foobar"))
   232  			})
   233  
   234  			headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
   235  			Expect(err).ToNot(HaveOccurred())
   236  			responseBuf := &bytes.Buffer{}
   237  			setRequest(encodeRequest(headRequest))
   238  			str.EXPECT().Context().Return(reqContext)
   239  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   240  			str.EXPECT().CancelRead(gomock.Any())
   241  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   242  			Expect(serr.err).ToNot(HaveOccurred())
   243  			hfs := decodeHeader(responseBuf)
   244  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   245  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   246  		})
   247  
   248  		It("response to HEAD request should also do content sniffing", func() {
   249  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   250  				w.Write([]byte("<html></html>"))
   251  			})
   252  
   253  			headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
   254  			Expect(err).ToNot(HaveOccurred())
   255  			responseBuf := &bytes.Buffer{}
   256  			setRequest(encodeRequest(headRequest))
   257  			str.EXPECT().Context().Return(reqContext)
   258  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   259  			str.EXPECT().CancelRead(gomock.Any())
   260  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   261  			Expect(serr.err).ToNot(HaveOccurred())
   262  			hfs := decodeHeader(responseBuf)
   263  			Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   264  			Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
   265  			Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
   266  		})
   267  
   268  		It("handles a aborting handler", func() {
   269  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   270  				panic(http.ErrAbortHandler)
   271  			})
   272  
   273  			responseBuf := &bytes.Buffer{}
   274  			setRequest(encodeRequest(exampleGetRequest))
   275  			str.EXPECT().Context().Return(reqContext)
   276  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   277  			str.EXPECT().CancelRead(gomock.Any())
   278  
   279  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   280  			Expect(serr.err).To(MatchError(errPanicked))
   281  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   282  		})
   283  
   284  		It("handles a panicking handler", func() {
   285  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   286  				panic("foobar")
   287  			})
   288  
   289  			responseBuf := &bytes.Buffer{}
   290  			setRequest(encodeRequest(exampleGetRequest))
   291  			str.EXPECT().Context().Return(reqContext)
   292  			str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   293  			str.EXPECT().CancelRead(gomock.Any())
   294  
   295  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   296  			Expect(serr.err).To(MatchError(errPanicked))
   297  			Expect(responseBuf.Bytes()).To(HaveLen(0))
   298  		})
   299  
   300  		Context("hijacking bidirectional streams", func() {
   301  			var conn *mockquic.MockEarlyConnection
   302  			testDone := make(chan struct{})
   303  
   304  			BeforeEach(func() {
   305  				testDone = make(chan struct{})
   306  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   307  				controlStr := mockquic.NewMockStream(mockCtrl)
   308  				controlStr.EXPECT().Write(gomock.Any())
   309  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   310  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   311  				conn.EXPECT().LocalAddr().AnyTimes()
   312  			})
   313  
   314  			AfterEach(func() { testDone <- struct{}{} })
   315  
   316  			It("hijacks a bidirectional stream of unknown frame type", func() {
   317  				frameTypeChan := make(chan FrameType, 1)
   318  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   319  					Expect(e).ToNot(HaveOccurred())
   320  					frameTypeChan <- ft
   321  					return true, nil
   322  				}
   323  
   324  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   325  				unknownStr := mockquic.NewMockStream(mockCtrl)
   326  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   327  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   328  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   329  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   330  					<-testDone
   331  					return nil, errors.New("test done")
   332  				})
   333  				s.handleConn(conn)
   334  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   335  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   336  			})
   337  
   338  			It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
   339  				frameTypeChan := make(chan FrameType, 1)
   340  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   341  					Expect(e).ToNot(HaveOccurred())
   342  					frameTypeChan <- ft
   343  					return false, nil
   344  				}
   345  
   346  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   347  				unknownStr := mockquic.NewMockStream(mockCtrl)
   348  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   349  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   350  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   351  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   352  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   353  					<-testDone
   354  					return nil, errors.New("test done")
   355  				})
   356  				s.handleConn(conn)
   357  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   358  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   359  			})
   360  
   361  			It("cancels writing when hijacker returned error", func() {
   362  				frameTypeChan := make(chan FrameType, 1)
   363  				s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
   364  					Expect(e).ToNot(HaveOccurred())
   365  					frameTypeChan <- ft
   366  					return false, errors.New("error in hijacker")
   367  				}
   368  
   369  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
   370  				unknownStr := mockquic.NewMockStream(mockCtrl)
   371  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   372  				unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
   373  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   374  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   375  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   376  					<-testDone
   377  					return nil, errors.New("test done")
   378  				})
   379  				s.handleConn(conn)
   380  				Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
   381  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   382  			})
   383  
   384  			It("handles errors that occur when reading the stream type", func() {
   385  				testErr := errors.New("test error")
   386  				done := make(chan struct{})
   387  				unknownStr := mockquic.NewMockStream(mockCtrl)
   388  				s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
   389  					defer close(done)
   390  					Expect(ft).To(BeZero())
   391  					Expect(str).To(Equal(unknownStr))
   392  					Expect(err).To(MatchError(testErr))
   393  					return true, nil
   394  				}
   395  
   396  				unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
   397  				conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
   398  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   399  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   400  					<-testDone
   401  					return nil, errors.New("test done")
   402  				})
   403  				s.handleConn(conn)
   404  				Eventually(done).Should(BeClosed())
   405  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   406  			})
   407  		})
   408  
   409  		Context("hijacking unidirectional streams", func() {
   410  			var conn *mockquic.MockEarlyConnection
   411  			testDone := make(chan struct{})
   412  
   413  			BeforeEach(func() {
   414  				testDone = make(chan struct{})
   415  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   416  				controlStr := mockquic.NewMockStream(mockCtrl)
   417  				controlStr.EXPECT().Write(gomock.Any())
   418  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   419  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   420  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   421  				conn.EXPECT().LocalAddr().AnyTimes()
   422  			})
   423  
   424  			AfterEach(func() { testDone <- struct{}{} })
   425  
   426  			It("hijacks an unidirectional stream of unknown stream type", func() {
   427  				streamTypeChan := make(chan StreamType, 1)
   428  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   429  					Expect(err).ToNot(HaveOccurred())
   430  					streamTypeChan <- st
   431  					return true
   432  				}
   433  
   434  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   435  				unknownStr := mockquic.NewMockStream(mockCtrl)
   436  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   437  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   438  					return unknownStr, nil
   439  				})
   440  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   441  					<-testDone
   442  					return nil, errors.New("test done")
   443  				})
   444  				s.handleConn(conn)
   445  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   446  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   447  			})
   448  
   449  			It("handles errors that occur when reading the stream type", func() {
   450  				testErr := errors.New("test error")
   451  				done := make(chan struct{})
   452  				unknownStr := mockquic.NewMockStream(mockCtrl)
   453  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
   454  					defer close(done)
   455  					Expect(st).To(BeZero())
   456  					Expect(str).To(Equal(unknownStr))
   457  					Expect(err).To(MatchError(testErr))
   458  					return true
   459  				}
   460  
   461  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
   462  				conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
   463  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   464  					<-testDone
   465  					return nil, errors.New("test done")
   466  				})
   467  				s.handleConn(conn)
   468  				Eventually(done).Should(BeClosed())
   469  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   470  			})
   471  
   472  			It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
   473  				streamTypeChan := make(chan StreamType, 1)
   474  				s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
   475  					Expect(err).ToNot(HaveOccurred())
   476  					streamTypeChan <- st
   477  					return false
   478  				}
   479  
   480  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
   481  				unknownStr := mockquic.NewMockStream(mockCtrl)
   482  				unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   483  				unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   484  
   485  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   486  					return unknownStr, nil
   487  				})
   488  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   489  					<-testDone
   490  					return nil, errors.New("test done")
   491  				})
   492  				s.handleConn(conn)
   493  				Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
   494  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   495  			})
   496  		})
   497  
   498  		Context("control stream handling", func() {
   499  			var conn *mockquic.MockEarlyConnection
   500  			testDone := make(chan struct{})
   501  
   502  			BeforeEach(func() {
   503  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   504  				controlStr := mockquic.NewMockStream(mockCtrl)
   505  				controlStr.EXPECT().Write(gomock.Any())
   506  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   507  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   508  				conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
   509  				conn.EXPECT().LocalAddr().AnyTimes()
   510  			})
   511  
   512  			AfterEach(func() { testDone <- struct{}{} })
   513  
   514  			It("parses the SETTINGS frame", func() {
   515  				b := quicvarint.Append(nil, streamTypeControlStream)
   516  				b = (&settingsFrame{}).Append(b)
   517  				controlStr := mockquic.NewMockStream(mockCtrl)
   518  				r := bytes.NewReader(b)
   519  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   520  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   521  					return controlStr, nil
   522  				})
   523  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   524  					<-testDone
   525  					return nil, errors.New("test done")
   526  				})
   527  				s.handleConn(conn)
   528  				time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
   529  			})
   530  
   531  			for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
   532  				streamType := t
   533  				name := "encoder"
   534  				if streamType == streamTypeQPACKDecoderStream {
   535  					name = "decoder"
   536  				}
   537  
   538  				It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
   539  					buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
   540  					str := mockquic.NewMockStream(mockCtrl)
   541  					str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   542  
   543  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   544  						return str, nil
   545  					})
   546  					conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   547  						<-testDone
   548  						return nil, errors.New("test done")
   549  					})
   550  					s.handleConn(conn)
   551  					time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
   552  				})
   553  			}
   554  
   555  			It("reset streams other than the control stream and the QPACK streams", func() {
   556  				buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337))
   557  				str := mockquic.NewMockStream(mockCtrl)
   558  				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
   559  				done := make(chan struct{})
   560  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) })
   561  
   562  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   563  					return str, nil
   564  				})
   565  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   566  					<-testDone
   567  					return nil, errors.New("test done")
   568  				})
   569  				s.handleConn(conn)
   570  				Eventually(done).Should(BeClosed())
   571  			})
   572  
   573  			It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
   574  				b := quicvarint.Append(nil, streamTypeControlStream)
   575  				b = (&dataFrame{}).Append(b)
   576  				controlStr := mockquic.NewMockStream(mockCtrl)
   577  				r := bytes.NewReader(b)
   578  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   579  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   580  					return controlStr, nil
   581  				})
   582  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   583  					<-testDone
   584  					return nil, errors.New("test done")
   585  				})
   586  				done := make(chan struct{})
   587  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   588  					close(done)
   589  					return nil
   590  				})
   591  				s.handleConn(conn)
   592  				Eventually(done).Should(BeClosed())
   593  			})
   594  
   595  			It("errors when parsing the frame on the control stream fails", func() {
   596  				b := quicvarint.Append(nil, streamTypeControlStream)
   597  				b = (&settingsFrame{}).Append(b)
   598  				r := bytes.NewReader(b[:len(b)-1])
   599  				controlStr := mockquic.NewMockStream(mockCtrl)
   600  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   601  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   602  					return controlStr, nil
   603  				})
   604  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   605  					<-testDone
   606  					return nil, errors.New("test done")
   607  				})
   608  				done := make(chan struct{})
   609  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   610  					close(done)
   611  					return nil
   612  				})
   613  				s.handleConn(conn)
   614  				Eventually(done).Should(BeClosed())
   615  			})
   616  
   617  			It("errors when the client opens a push stream", func() {
   618  				b := quicvarint.Append(nil, streamTypePushStream)
   619  				b = (&dataFrame{}).Append(b)
   620  				r := bytes.NewReader(b)
   621  				controlStr := mockquic.NewMockStream(mockCtrl)
   622  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   623  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   624  					return controlStr, nil
   625  				})
   626  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   627  					<-testDone
   628  					return nil, errors.New("test done")
   629  				})
   630  				done := make(chan struct{})
   631  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   632  					close(done)
   633  					return nil
   634  				})
   635  				s.handleConn(conn)
   636  				Eventually(done).Should(BeClosed())
   637  			})
   638  
   639  			It("errors when the client advertises datagram support (and we enabled support for it)", func() {
   640  				s.EnableDatagrams = true
   641  				b := quicvarint.Append(nil, streamTypeControlStream)
   642  				b = (&settingsFrame{Datagram: true}).Append(b)
   643  				r := bytes.NewReader(b)
   644  				controlStr := mockquic.NewMockStream(mockCtrl)
   645  				controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
   646  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   647  					return controlStr, nil
   648  				})
   649  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   650  					<-testDone
   651  					return nil, errors.New("test done")
   652  				})
   653  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
   654  				done := make(chan struct{})
   655  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
   656  					close(done)
   657  					return nil
   658  				})
   659  				s.handleConn(conn)
   660  				Eventually(done).Should(BeClosed())
   661  			})
   662  		})
   663  
   664  		Context("stream- and connection-level errors", func() {
   665  			var conn *mockquic.MockEarlyConnection
   666  			testDone := make(chan struct{})
   667  
   668  			BeforeEach(func() {
   669  				testDone = make(chan struct{})
   670  				addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
   671  				conn = mockquic.NewMockEarlyConnection(mockCtrl)
   672  				controlStr := mockquic.NewMockStream(mockCtrl)
   673  				controlStr.EXPECT().Write(gomock.Any())
   674  				conn.EXPECT().OpenUniStream().Return(controlStr, nil)
   675  				conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
   676  					<-testDone
   677  					return nil, errors.New("test done")
   678  				})
   679  				conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
   680  				conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
   681  				conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   682  				conn.EXPECT().LocalAddr().AnyTimes()
   683  				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
   684  			})
   685  
   686  			AfterEach(func() { testDone <- struct{}{} })
   687  
   688  			It("cancels reading when client sends a body in GET request", func() {
   689  				var handlerCalled bool
   690  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   691  					handlerCalled = true
   692  				})
   693  
   694  				requestData := encodeRequest(exampleGetRequest)
   695  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   696  				b = append(b, []byte("foobar")...)
   697  				responseBuf := &bytes.Buffer{}
   698  				setRequest(append(requestData, b...))
   699  				done := make(chan struct{})
   700  				str.EXPECT().Context().Return(reqContext)
   701  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   702  				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   703  				str.EXPECT().Close().Do(func() error { close(done); return nil })
   704  
   705  				s.handleConn(conn)
   706  				Eventually(done).Should(BeClosed())
   707  				hfs := decodeHeader(responseBuf)
   708  				Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
   709  				Expect(handlerCalled).To(BeTrue())
   710  			})
   711  
   712  			It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
   713  				handlerCalled := make(chan struct{})
   714  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   715  					defer close(handlerCalled)
   716  					r.Body.(HTTPStreamer).HTTPStream()
   717  					str.Write([]byte("foobar"))
   718  				})
   719  
   720  				requestData := encodeRequest(exampleGetRequest)
   721  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   722  				b = append(b, []byte("foobar")...)
   723  				setRequest(append(requestData, b...))
   724  				str.EXPECT().Context().Return(reqContext)
   725  				str.EXPECT().Write([]byte("foobar")).Return(6, nil)
   726  
   727  				s.handleConn(conn)
   728  				Eventually(handlerCalled).Should(BeClosed())
   729  			})
   730  
   731  			It("errors when the client sends a too large header frame", func() {
   732  				s.MaxHeaderBytes = 20
   733  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   734  					Fail("Handler should not be called.")
   735  				})
   736  
   737  				requestData := encodeRequest(exampleGetRequest)
   738  				b := (&dataFrame{Length: 6}).Append(nil) // add a body
   739  				b = append(b, []byte("foobar")...)
   740  				responseBuf := &bytes.Buffer{}
   741  				setRequest(append(requestData, b...))
   742  				done := make(chan struct{})
   743  				str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
   744  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   745  
   746  				s.handleConn(conn)
   747  				Eventually(done).Should(BeClosed())
   748  			})
   749  
   750  			It("handles a request for which the client immediately resets the stream", func() {
   751  				handlerCalled := make(chan struct{})
   752  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   753  					close(handlerCalled)
   754  				})
   755  
   756  				testErr := errors.New("stream reset")
   757  				done := make(chan struct{})
   758  				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
   759  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
   760  
   761  				s.handleConn(conn)
   762  				Consistently(handlerCalled).ShouldNot(BeClosed())
   763  			})
   764  
   765  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   766  				handlerCalled := make(chan struct{})
   767  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   768  					close(handlerCalled)
   769  				})
   770  
   771  				b := (&dataFrame{}).Append(nil)
   772  				setRequest(b)
   773  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   774  					return len(p), nil
   775  				}).AnyTimes()
   776  
   777  				done := make(chan struct{})
   778  				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
   779  					close(done)
   780  					return nil
   781  				})
   782  				s.handleConn(conn)
   783  				Eventually(done).Should(BeClosed())
   784  			})
   785  
   786  			It("closes the connection when the first frame is not a HEADERS frame", func() {
   787  				handlerCalled := make(chan struct{})
   788  				s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   789  					close(handlerCalled)
   790  				})
   791  
   792  				// use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest,
   793  				// but the request will still end up larger than DefaultMaxHeaderBytes.
   794  				url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
   795  				req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
   796  				Expect(err).ToNot(HaveOccurred())
   797  				setRequest(encodeRequest(req))
   798  				// str.EXPECT().Context().Return(reqContext)
   799  				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   800  					return len(p), nil
   801  				}).AnyTimes()
   802  				done := make(chan struct{})
   803  				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
   804  
   805  				s.handleConn(conn)
   806  				Eventually(done).Should(BeClosed())
   807  			})
   808  		})
   809  
   810  		It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
   811  			handlerCalled := make(chan struct{})
   812  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   813  				r.Body = struct {
   814  					io.Reader
   815  					io.Closer
   816  				}{}
   817  				close(handlerCalled)
   818  			})
   819  
   820  			setRequest(encodeRequest(examplePostRequest))
   821  			str.EXPECT().Context().Return(reqContext)
   822  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   823  				return len(p), nil
   824  			}).AnyTimes()
   825  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   826  
   827  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   828  			Expect(serr.err).ToNot(HaveOccurred())
   829  			Eventually(handlerCalled).Should(BeClosed())
   830  		})
   831  
   832  		It("cancels the request context when the stream is closed", func() {
   833  			handlerCalled := make(chan struct{})
   834  			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   835  				defer GinkgoRecover()
   836  				Expect(r.Context().Done()).To(BeClosed())
   837  				Expect(r.Context().Err()).To(MatchError(context.Canceled))
   838  				close(handlerCalled)
   839  			})
   840  			setRequest(encodeRequest(examplePostRequest))
   841  
   842  			reqContext, cancel := context.WithCancel(context.Background())
   843  			cancel()
   844  			str.EXPECT().Context().Return(reqContext)
   845  			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
   846  				return len(p), nil
   847  			}).AnyTimes()
   848  			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
   849  
   850  			serr := s.handleRequest(conn, str, qpackDecoder, nil)
   851  			Expect(serr.err).ToNot(HaveOccurred())
   852  			Eventually(handlerCalled).Should(BeClosed())
   853  		})
   854  	})
   855  
   856  	Context("setting http headers", func() {
   857  		BeforeEach(func() {
   858  			s.QuicConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}}
   859  		})
   860  
   861  		var ln1 QUICEarlyListener
   862  		var ln2 QUICEarlyListener
   863  		expected := http.Header{
   864  			"Alt-Svc": {`h3=":443"; ma=2592000`},
   865  		}
   866  
   867  		addListener := func(addr string, ln *QUICEarlyListener) {
   868  			mln := newMockAddrListener(addr)
   869  			mln.EXPECT().Addr()
   870  			*ln = mln
   871  			s.addListener(ln)
   872  		}
   873  
   874  		removeListener := func(ln *QUICEarlyListener) {
   875  			s.removeListener(ln)
   876  		}
   877  
   878  		checkSetHeaders := func(expected gmtypes.GomegaMatcher) {
   879  			hdr := http.Header{}
   880  			Expect(s.SetQuicHeaders(hdr)).To(Succeed())
   881  			Expect(hdr).To(expected)
   882  		}
   883  
   884  		checkSetHeaderError := func() {
   885  			hdr := http.Header{}
   886  			Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
   887  		}
   888  
   889  		It("sets proper headers with numeric port", func() {
   890  			addListener(":443", &ln1)
   891  			checkSetHeaders(Equal(expected))
   892  			removeListener(&ln1)
   893  			checkSetHeaderError()
   894  		})
   895  
   896  		It("sets proper headers with full addr", func() {
   897  			addListener("127.0.0.1:443", &ln1)
   898  			checkSetHeaders(Equal(expected))
   899  			removeListener(&ln1)
   900  			checkSetHeaderError()
   901  		})
   902  
   903  		It("sets proper headers with string port", func() {
   904  			addListener(":https", &ln1)
   905  			checkSetHeaders(Equal(expected))
   906  			removeListener(&ln1)
   907  			checkSetHeaderError()
   908  		})
   909  
   910  		It("works multiple times", func() {
   911  			addListener(":https", &ln1)
   912  			checkSetHeaders(Equal(expected))
   913  			checkSetHeaders(Equal(expected))
   914  			removeListener(&ln1)
   915  			checkSetHeaderError()
   916  		})
   917  
   918  		It("works if the quic.Config sets QUIC versions", func() {
   919  			s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version2}
   920  			addListener(":443", &ln1)
   921  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   922  			removeListener(&ln1)
   923  			checkSetHeaderError()
   924  		})
   925  
   926  		It("uses s.Port if set to a non-zero value", func() {
   927  			s.Port = 8443
   928  			addListener(":443", &ln1)
   929  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
   930  			removeListener(&ln1)
   931  			checkSetHeaderError()
   932  		})
   933  
   934  		It("uses s.Addr if listeners don't have ports available", func() {
   935  			s.Addr = ":443"
   936  			mln := &noPortListener{newMockAddrListener("")}
   937  			mln.EXPECT().Addr()
   938  			ln1 = mln
   939  			s.addListener(&ln1)
   940  			checkSetHeaders(Equal(expected))
   941  			s.removeListener(&ln1)
   942  			checkSetHeaderError()
   943  		})
   944  
   945  		It("properly announces multiple listeners", func() {
   946  			addListener(":443", &ln1)
   947  			addListener(":8443", &ln2)
   948  			checkSetHeaders(Or(
   949  				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
   950  				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
   951  			))
   952  			removeListener(&ln1)
   953  			removeListener(&ln2)
   954  			checkSetHeaderError()
   955  		})
   956  
   957  		It("doesn't duplicate Alt-Svc values", func() {
   958  			s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version1}
   959  			addListener(":443", &ln1)
   960  			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
   961  			removeListener(&ln1)
   962  			checkSetHeaderError()
   963  		})
   964  	})
   965  
   966  	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
   967  		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
   968  	})
   969  
   970  	It("should nop-Close() when s.server is nil", func() {
   971  		Expect((&Server{}).Close()).To(Succeed())
   972  	})
   973  
   974  	It("errors when ListenAndServeTLS is called after Close", func() {
   975  		serv := &Server{}
   976  		Expect(serv.Close()).To(Succeed())
   977  		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
   978  	})
   979  
   980  	It("handles concurrent Serve and Close", func() {
   981  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   982  		Expect(err).ToNot(HaveOccurred())
   983  		c, err := net.ListenUDP("udp", addr)
   984  		Expect(err).ToNot(HaveOccurred())
   985  		done := make(chan struct{})
   986  		go func() {
   987  			defer GinkgoRecover()
   988  			defer close(done)
   989  			s.Serve(c)
   990  		}()
   991  		runtime.Gosched()
   992  		s.Close()
   993  		Eventually(done).Should(BeClosed())
   994  	})
   995  
   996  	Context("ConfigureTLSConfig", func() {
   997  		It("advertises v1 by default", func() {
   998  			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
   999  			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}})
  1000  			Expect(err).ToNot(HaveOccurred())
  1001  			defer ln.Close()
  1002  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1003  			Expect(err).ToNot(HaveOccurred())
  1004  			defer c.CloseWithError(0, "")
  1005  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1006  		})
  1007  
  1008  		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
  1009  			var receivedConf *tls.Config
  1010  			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
  1011  				receivedConf = tlsConf
  1012  				return nil, errors.New("listen err")
  1013  			}
  1014  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1015  			Expect(receivedConf).ToNot(BeNil())
  1016  		})
  1017  
  1018  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
  1019  			tlsConf := &tls.Config{
  1020  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  1021  					c := testdata.GetTLSConfig()
  1022  					c.NextProtos = []string{"foo", "bar"}
  1023  					return c, nil
  1024  				},
  1025  			}
  1026  
  1027  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1028  			Expect(err).ToNot(HaveOccurred())
  1029  			defer ln.Close()
  1030  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1031  			Expect(err).ToNot(HaveOccurred())
  1032  			defer c.CloseWithError(0, "")
  1033  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1034  		})
  1035  
  1036  		It("works if GetConfigForClient returns a nil tls.Config", func() {
  1037  			tlsConf := testdata.GetTLSConfig()
  1038  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
  1039  
  1040  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1041  			Expect(err).ToNot(HaveOccurred())
  1042  			defer ln.Close()
  1043  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1044  			Expect(err).ToNot(HaveOccurred())
  1045  			defer c.CloseWithError(0, "")
  1046  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1047  		})
  1048  
  1049  		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
  1050  			tlsClientConf := testdata.GetTLSConfig()
  1051  			tlsClientConf.NextProtos = []string{"foo", "bar"}
  1052  			tlsConf := &tls.Config{
  1053  				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  1054  					return tlsClientConf, nil
  1055  				},
  1056  			}
  1057  
  1058  			ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}})
  1059  			Expect(err).ToNot(HaveOccurred())
  1060  			defer ln.Close()
  1061  			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
  1062  			Expect(err).ToNot(HaveOccurred())
  1063  			defer c.CloseWithError(0, "")
  1064  			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
  1065  			// check that the original config was not modified
  1066  			Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
  1067  		})
  1068  	})
  1069  
	// Serve tests: quicListen is replaced by a stub that returns mock
	// listeners, so Serve(conn) runs against mocks instead of real QUIC
	// listeners. The original quicListen is restored after each test.
	Context("Serve", func() {
		origQuicListen := quicListen

		AfterEach(func() {
			quicListen = origQuicListen
		})

		It("serves a packet conn", func() {
			ln := newMockAddrListener(":443")
			conn := &net.UDPConn{}
			// The stub asserts that Serve passes the given conn through.
			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
				Expect(c).To(Equal(conn))
				return ln, nil
			}

			// Local Server, shadowing the shared s for this test.
			s := &Server{
				TLSConfig: &tls.Config{},
			}

			// Accept blocks until stopAccept is closed by the mocked Close below,
			// which keeps Serve running until the server is closed.
			stopAccept := make(chan struct{})
			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept
				return nil, errors.New("closed")
			})
			ln.EXPECT().Addr() // generate alt-svc headers
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done)
				s.Serve(conn)
			}()

			Consistently(done).ShouldNot(BeClosed())
			// Close is only expected from here on; a Close call during the
			// Consistently window above would be an unexpected mock call.
			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
			Expect(s.Close()).To(Succeed())
			Eventually(done).Should(BeClosed())
		})

		It("serves two packet conns", func() {
			ln1 := newMockAddrListener(":443")
			ln2 := newMockAddrListener(":8443")
			// The stub hands out ln1, then ln2, in call order, independent of
			// which conn it is called with.
			lns := make(chan QUICEarlyListener, 2)
			lns <- ln1
			lns <- ln2
			conn1 := &net.UDPConn{}
			conn2 := &net.UDPConn{}
			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
				return <-lns, nil
			}

			s := &Server{
				TLSConfig: &tls.Config{},
			}

			stopAccept1 := make(chan struct{})
			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept1
				return nil, errors.New("closed")
			})
			ln1.EXPECT().Addr() // generate alt-svc headers
			stopAccept2 := make(chan struct{})
			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept2
				return nil, errors.New("closed")
			})
			ln2.EXPECT().Addr()

			done1 := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done1)
				s.Serve(conn1)
			}()
			done2 := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done2)
				s.Serve(conn2)
			}()

			Consistently(done1).ShouldNot(BeClosed())
			Expect(done2).ToNot(BeClosed())
			// A single s.Close must close both listeners and end both Serve calls.
			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
			Expect(s.Close()).To(Succeed())
			Eventually(done1).Should(BeClosed())
			Eventually(done2).Should(BeClosed())
		})
	})
  1159  
	// ServeListener tests: the listener is supplied directly, so the stubbed
	// quicListen only records whether it was called and the tests assert that
	// it never is. The original quicListen is restored after each test.
	Context("ServeListener", func() {
		origQuicListen := quicListen

		AfterEach(func() {
			quicListen = origQuicListen
		})

		It("serves a listener", func() {
			var called int32
			ln := newMockAddrListener(":443")
			quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
				atomic.StoreInt32(&called, 1)
				return ln, nil
			}

			// Local Server, shadowing the shared s for this test.
			s := &Server{}

			// Accept blocks until stopAccept is closed by the mocked Close below,
			// which keeps ServeListener running until the server is closed.
			stopAccept := make(chan struct{})
			ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept
				return nil, errors.New("closed")
			})
			ln.EXPECT().Addr() // generate alt-svc headers
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done)
				s.ServeListener(ln)
			}()

			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
			Consistently(done).ShouldNot(BeClosed())
			// Close is only expected from here on; an earlier Close call would be
			// an unexpected mock call.
			ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
			Expect(s.Close()).To(Succeed())
			Eventually(done).Should(BeClosed())
		})

		It("serves two listeners", func() {
			var called int32
			ln1 := newMockAddrListener(":443")
			ln2 := newMockAddrListener(":8443")
			lns := make(chan QUICEarlyListener, 2)
			lns <- ln1
			lns <- ln2
			quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
				atomic.StoreInt32(&called, 1)
				return <-lns, nil
			}

			s := &Server{}

			stopAccept1 := make(chan struct{})
			ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept1
				return nil, errors.New("closed")
			})
			ln1.EXPECT().Addr() // generate alt-svc headers
			stopAccept2 := make(chan struct{})
			ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-stopAccept2
				return nil, errors.New("closed")
			})
			ln2.EXPECT().Addr()

			done1 := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done1)
				s.ServeListener(ln1)
			}()
			done2 := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done2)
				s.ServeListener(ln2)
			}()

			Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
			Consistently(done1).ShouldNot(BeClosed())
			Expect(done2).ToNot(BeClosed())
			// A single s.Close must close both listeners and end both ServeListener calls.
			ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
			ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil })
			Expect(s.Close()).To(Succeed())
			Eventually(done1).Should(BeClosed())
			Eventually(done2).Should(BeClosed())
		})
	})
  1247  
  1248  	Context("ServeQUICConn", func() {
  1249  		It("serves a QUIC connection", func() {
  1250  			mux := http.NewServeMux()
  1251  			mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
  1252  				w.Write([]byte("foobar"))
  1253  			})
  1254  			s.Handler = mux
  1255  			tlsConf := testdata.GetTLSConfig()
  1256  			tlsConf.NextProtos = []string{NextProtoH3}
  1257  			conn := mockquic.NewMockEarlyConnection(mockCtrl)
  1258  			controlStr := mockquic.NewMockStream(mockCtrl)
  1259  			controlStr.EXPECT().Write(gomock.Any())
  1260  			conn.EXPECT().OpenUniStream().Return(controlStr, nil)
  1261  			testDone := make(chan struct{})
  1262  			conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
  1263  				<-testDone
  1264  				return nil, errors.New("test done")
  1265  			}).MaxTimes(1)
  1266  			conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})
  1267  			s.ServeQUICConn(conn)
  1268  			close(testDone)
  1269  		})
  1270  	})
  1271  
  1272  	Context("ListenAndServe", func() {
  1273  		BeforeEach(func() {
  1274  			s.Addr = "localhost:0"
  1275  		})
  1276  
  1277  		AfterEach(func() {
  1278  			Expect(s.Close()).To(Succeed())
  1279  		})
  1280  
  1281  		It("uses the quic.Config to start the QUIC server", func() {
  1282  			conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
  1283  			var receivedConf *quic.Config
  1284  			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1285  				receivedConf = config
  1286  				return nil, errors.New("listen err")
  1287  			}
  1288  			s.QuicConfig = conf
  1289  			Expect(s.ListenAndServe()).To(HaveOccurred())
  1290  			Expect(receivedConf).To(Equal(conf))
  1291  		})
  1292  	})
  1293  
  1294  	It("closes gracefully", func() {
  1295  		Expect(s.CloseGracefully(0)).To(Succeed())
  1296  	})
  1297  
  1298  	It("errors when listening fails", func() {
  1299  		testErr := errors.New("listen error")
  1300  		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1301  			return nil, testErr
  1302  		}
  1303  		fullpem, privkey := testdata.GetCertificatePaths()
  1304  		Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
  1305  	})
  1306  
  1307  	It("supports H3_DATAGRAM", func() {
  1308  		s.EnableDatagrams = true
  1309  		var receivedConf *quic.Config
  1310  		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
  1311  			receivedConf = config
  1312  			return nil, errors.New("listen err")
  1313  		}
  1314  		Expect(s.ListenAndServe()).To(HaveOccurred())
  1315  		Expect(receivedConf.EnableDatagrams).To(BeTrue())
  1316  	})
  1317  })