github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/h2quic/server.go (about)

     1  package h2quic
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"runtime"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	quic "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go"
    16  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    17  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    18  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr"
    19  	"golang.org/x/net/http2"
    20  	"golang.org/x/net/http2/hpack"
    21  )
    22  
    23  type streamCreator interface {
    24  	quic.Session
    25  	GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
    26  }
    27  
    28  type remoteCloser interface {
    29  	CloseRemote(protocol.ByteCount)
    30  }
    31  
    32  // allows mocking of quic.Listen and quic.ListenAddr
    33  var (
    34  	quicListen     = quic.Listen
    35  	quicListenAddr = quic.ListenAddr
    36  )
    37  
// Server is a HTTP2 server listening for QUIC connections.
//
// The embedded *http.Server supplies Addr, TLSConfig and Handler to the
// listen/serve methods; ListenAndServe and serveImpl return an error when it
// is nil.
type Server struct {
	*http.Server

	// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
	// If nil, it uses reasonable default values.
	QuicConfig *quic.Config

	// Private flag for demo, do not use.
	// When set, each request goroutine closes the whole QUIC session about
	// 100ms after its handler has finished.
	CloseAfterFirstRequest bool

	// Port advertised in Alt-Svc by SetQuicHeaders; 0 until first computed.
	port uint32 // used atomically

	// listenerMutex guards listener and closed (see serveImpl and Close).
	listenerMutex sync.Mutex
	listener      quic.Listener
	closed        bool

	// Lazily built, comma-separated Alt-Svc version list used by SetQuicHeaders.
	supportedVersionsAsString string

	logger utils.Logger // will be set by Server.serveImpl()
}
    59  
    60  // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
    61  func (s *Server) ListenAndServe() error {
    62  	if s.Server == nil {
    63  		return errors.New("use of h2quic.Server without http.Server")
    64  	}
    65  	return s.serveImpl(s.TLSConfig, nil)
    66  }
    67  
    68  // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
    69  func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
    70  	var err error
    71  	certs := make([]tls.Certificate, 1)
    72  	certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
    73  	if err != nil {
    74  		return err
    75  	}
    76  	// We currently only use the cert-related stuff from tls.Config,
    77  	// so we don't need to make a full copy.
    78  	config := &tls.Config{
    79  		Certificates: certs,
    80  	}
    81  	return s.serveImpl(config, nil)
    82  }
    83  
    84  // Serve an existing UDP connection.
    85  func (s *Server) Serve(conn net.PacketConn) error {
    86  	return s.serveImpl(s.TLSConfig, conn)
    87  }
    88  
    89  func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
    90  	if s.Server == nil {
    91  		return errors.New("use of h2quic.Server without http.Server")
    92  	}
    93  	s.logger = utils.DefaultLogger.WithPrefix("server")
    94  	s.listenerMutex.Lock()
    95  	if s.closed {
    96  		s.listenerMutex.Unlock()
    97  		return errors.New("Server is already closed")
    98  	}
    99  	if s.listener != nil {
   100  		s.listenerMutex.Unlock()
   101  		return errors.New("ListenAndServe may only be called once")
   102  	}
   103  
   104  	var ln quic.Listener
   105  	var err error
   106  	if conn == nil {
   107  		ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
   108  	} else {
   109  		ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
   110  	}
   111  	if err != nil {
   112  		s.listenerMutex.Unlock()
   113  		return err
   114  	}
   115  	s.listener = ln
   116  	s.listenerMutex.Unlock()
   117  
   118  	for {
   119  		sess, err := ln.Accept()
   120  		if err != nil {
   121  			return err
   122  		}
   123  		go s.handleHeaderStream(sess.(streamCreator))
   124  	}
   125  }
   126  
   127  func (s *Server) handleHeaderStream(session streamCreator) {
   128  	stream, err := session.AcceptStream()
   129  	if err != nil {
   130  		session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
   131  		return
   132  	}
   133  
   134  	hpackDecoder := hpack.NewDecoder(4096, nil)
   135  	h2framer := http2.NewFramer(nil, stream)
   136  
   137  	var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
   138  	for {
   139  		if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
   140  			// QuicErrors must originate from stream.Read() returning an error.
   141  			// In this case, the session has already logged the error, so we don't
   142  			// need to log it again.
   143  			errorCode := qerr.InternalError
   144  			if qerr, ok := err.(*qerr.QuicError); ok {
   145  				errorCode = qerr.ErrorCode
   146  				s.logger.Errorf("error handling h2 request: %s", err.Error())
   147  			}
   148  			session.CloseWithError(quic.ErrorCode(errorCode), err)
   149  			return
   150  		}
   151  	}
   152  }
   153  
   154  func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
   155  	h2frame, err := h2framer.ReadFrame()
   156  	if err != nil {
   157  		return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
   158  	}
   159  	var h2headersFrame *http2.HeadersFrame
   160  	switch f := h2frame.(type) {
   161  	case *http2.PriorityFrame:
   162  		// ignore PRIORITY frames
   163  		s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
   164  		return nil
   165  	case *http2.HeadersFrame:
   166  		h2headersFrame = f
   167  	default:
   168  		return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
   169  	}
   170  
   171  	if !h2headersFrame.HeadersEnded() {
   172  		return errors.New("http2 header continuation not implemented")
   173  	}
   174  	headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
   175  	if err != nil {
   176  		s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
   177  		return err
   178  	}
   179  
   180  	req, err := requestFromHeaders(headers)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	if s.logger.Debug() {
   186  		s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
   187  	} else {
   188  		s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
   189  	}
   190  
   191  	dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
   192  	if err != nil {
   193  		return err
   194  	}
   195  	// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
   196  	if dataStream == nil {
   197  		return nil
   198  	}
   199  
   200  	// handleRequest should be as non-blocking as possible to minimize
   201  	// head-of-line blocking. Potentially blocking code is run in a separate
   202  	// goroutine, enabling handleRequest to return before the code is executed.
   203  	go func() {
   204  		streamEnded := h2headersFrame.StreamEnded()
   205  		if streamEnded {
   206  			dataStream.(remoteCloser).CloseRemote(0)
   207  			streamEnded = true
   208  			_, _ = dataStream.Read([]byte{0}) // read the eof
   209  		}
   210  
   211  		req = req.WithContext(dataStream.Context())
   212  		reqBody := newRequestBody(dataStream)
   213  		req.Body = reqBody
   214  
   215  		req.RemoteAddr = session.RemoteAddr().String()
   216  
   217  		responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
   218  
   219  		handler := s.Handler
   220  		if handler == nil {
   221  			handler = http.DefaultServeMux
   222  		}
   223  		panicked := false
   224  		func() {
   225  			defer func() {
   226  				if p := recover(); p != nil {
   227  					// Copied from net/http/server.go
   228  					const size = 64 << 10
   229  					buf := make([]byte, size)
   230  					buf = buf[:runtime.Stack(buf, false)]
   231  					s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
   232  					panicked = true
   233  				}
   234  			}()
   235  			handler.ServeHTTP(responseWriter, req)
   236  		}()
   237  		if panicked {
   238  			responseWriter.WriteHeader(500)
   239  		} else {
   240  			responseWriter.WriteHeader(200)
   241  		}
   242  		if responseWriter.dataStream != nil {
   243  			if !streamEnded && !reqBody.requestRead {
   244  				// in gQUIC, the error code doesn't matter, so just use 0 here
   245  				responseWriter.dataStream.CancelRead(0)
   246  			}
   247  			responseWriter.dataStream.Close()
   248  		}
   249  		if s.CloseAfterFirstRequest {
   250  			time.Sleep(100 * time.Millisecond)
   251  			session.Close()
   252  		}
   253  	}()
   254  
   255  	return nil
   256  }
   257  
   258  // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
   259  // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
   260  func (s *Server) Close() error {
   261  	s.listenerMutex.Lock()
   262  	defer s.listenerMutex.Unlock()
   263  	s.closed = true
   264  	if s.listener != nil {
   265  		err := s.listener.Close()
   266  		s.listener = nil
   267  		return err
   268  	}
   269  	return nil
   270  }
   271  
   272  // CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
   273  // CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
   274  func (s *Server) CloseGracefully(timeout time.Duration) error {
   275  	// TODO: implement
   276  	return nil
   277  }
   278  
   279  // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
   280  // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
   281  //  Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
   282  func (s *Server) SetQuicHeaders(hdr http.Header) error {
   283  	port := atomic.LoadUint32(&s.port)
   284  
   285  	if port == 0 {
   286  		// Extract port from s.Server.Addr
   287  		_, portStr, err := net.SplitHostPort(s.Server.Addr)
   288  		if err != nil {
   289  			return err
   290  		}
   291  		portInt, err := net.LookupPort("tcp", portStr)
   292  		if err != nil {
   293  			return err
   294  		}
   295  		port = uint32(portInt)
   296  		atomic.StoreUint32(&s.port, port)
   297  	}
   298  
   299  	if s.supportedVersionsAsString == "" {
   300  		var versions []string
   301  		for _, v := range protocol.SupportedVersions {
   302  			versions = append(versions, v.ToAltSvc())
   303  		}
   304  		s.supportedVersionsAsString = strings.Join(versions, ",")
   305  	}
   306  
   307  	hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
   308  
   309  	return nil
   310  }
   311  
   312  // ListenAndServeQUIC listens on the UDP network address addr and calls the
   313  // handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
   314  // used when handler is nil.
   315  func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
   316  	server := &Server{
   317  		Server: &http.Server{
   318  			Addr:    addr,
   319  			Handler: handler,
   320  		},
   321  	}
   322  	return server.ListenAndServeTLS(certFile, keyFile)
   323  }
   324  
   325  // ListenAndServe listens on the given network address for both, TLS and QUIC
   326  // connetions in parallel. It returns if one of the two returns an error.
   327  // http.DefaultServeMux is used when handler is nil.
   328  // The correct Alt-Svc headers for QUIC are set.
   329  func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
   330  	// Load certs
   331  	var err error
   332  	certs := make([]tls.Certificate, 1)
   333  	certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
   334  	if err != nil {
   335  		return err
   336  	}
   337  	// We currently only use the cert-related stuff from tls.Config,
   338  	// so we don't need to make a full copy.
   339  	config := &tls.Config{
   340  		Certificates: certs,
   341  	}
   342  
   343  	// Open the listeners
   344  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
   345  	if err != nil {
   346  		return err
   347  	}
   348  	udpConn, err := net.ListenUDP("udp", udpAddr)
   349  	if err != nil {
   350  		return err
   351  	}
   352  	defer udpConn.Close()
   353  
   354  	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
   355  	if err != nil {
   356  		return err
   357  	}
   358  	tcpConn, err := net.ListenTCP("tcp", tcpAddr)
   359  	if err != nil {
   360  		return err
   361  	}
   362  	defer tcpConn.Close()
   363  
   364  	tlsConn := tls.NewListener(tcpConn, config)
   365  	defer tlsConn.Close()
   366  
   367  	// Start the servers
   368  	httpServer := &http.Server{
   369  		Addr:      addr,
   370  		TLSConfig: config,
   371  	}
   372  
   373  	quicServer := &Server{
   374  		Server: httpServer,
   375  	}
   376  
   377  	if handler == nil {
   378  		handler = http.DefaultServeMux
   379  	}
   380  	httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   381  		quicServer.SetQuicHeaders(w.Header())
   382  		handler.ServeHTTP(w, r)
   383  	})
   384  
   385  	hErr := make(chan error)
   386  	qErr := make(chan error)
   387  	go func() {
   388  		hErr <- httpServer.Serve(tlsConn)
   389  	}()
   390  	go func() {
   391  		qErr <- quicServer.Serve(udpConn)
   392  	}()
   393  
   394  	select {
   395  	case err := <-hErr:
   396  		quicServer.Close()
   397  		return err
   398  	case err := <-qErr:
   399  		// Cannot close the HTTP server or wait for requests to complete properly :/
   400  		return err
   401  	}
   402  }