github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/h2quic/client.go (about)

     1  package h2quic
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  
    13  	"golang.org/x/net/http2"
    14  	"golang.org/x/net/http2/hpack"
    15  	"golang.org/x/net/idna"
    16  
    17  	quic "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go"
    18  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    19  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    20  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr"
    21  )
    22  
// roundTripperOpts holds options shared from the owning RoundTripper that the
// client consults when building requests.
type roundTripperOpts struct {
	// DisableCompression, when true, stops the client from adding its own
	// gzip request and transparently decompressing the response.
	DisableCompression bool
}

// dialAddr is the dial function used by client.dial when no custom dialer is
// configured. It is a package variable rather than a direct call, presumably
// so it can be swapped out (e.g. in tests) — NOTE(review): confirm.
var dialAddr = quic.DialAddr
    28  
// client is a HTTP2 client doing QUIC requests
//
// A client is bound to a single authority (hostname) and, once dialed, a
// single QUIC session. Request headers are written to, and response headers
// read from, one shared header stream; each request additionally opens its
// own data stream for the request and response bodies.
type client struct {
	// mutex guards the responses map, which is written by RoundTrip and read
	// by the header stream goroutine.
	mutex sync.RWMutex

	// Configuration handed to the dialer (tlsConf, config) and consulted when
	// building requests (opts).
	tlsConf *tls.Config
	config  *quic.Config
	opts    *roundTripperOpts

	// hostname is the normalized host:port this client serves. handshakeErr
	// is the result of the single dial attempt guarded by dialOnce. dialer,
	// when non-nil, is used instead of the package-level dialAddr.
	hostname     string
	handshakeErr error
	dialOnce     sync.Once
	dialer       func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)

	// [Psiphon]
	// Fix close-while-dialing race condition by synchronizing access to
	// client.session and adding a closed flag to indicate if the client was
	// closed while a dial was in progress.
	sessionMutex sync.Mutex
	closed       bool
	session      quic.Session

	// Header stream state. headerStream and requestWriter are set up by dial;
	// headerErr is set by the header stream goroutine just before it closes
	// headerErrored.
	headerStream  quic.Stream
	headerErr     *qerr.QuicError
	headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
	requestWriter *requestWriter

	// responses maps a request's data stream ID to the channel on which its
	// parsed response headers are delivered.
	responses map[protocol.StreamID]chan *http.Response

	logger utils.Logger
}
    59  
    60  var _ http.RoundTripper = &client{}
    61  
    62  var defaultQuicConfig = &quic.Config{
    63  	RequestConnectionIDOmission: true,
    64  	KeepAlive:                   true,
    65  }
    66  
    67  // newClient creates a new client
    68  func newClient(
    69  	hostname string,
    70  	tlsConfig *tls.Config,
    71  	opts *roundTripperOpts,
    72  	quicConfig *quic.Config,
    73  	dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
    74  ) *client {
    75  	config := defaultQuicConfig
    76  	if quicConfig != nil {
    77  		config = quicConfig
    78  	}
    79  	return &client{
    80  		hostname:      authorityAddr("https", hostname),
    81  		responses:     make(map[protocol.StreamID]chan *http.Response),
    82  		tlsConf:       tlsConfig,
    83  		config:        config,
    84  		opts:          opts,
    85  		headerErrored: make(chan struct{}),
    86  		dialer:        dialer,
    87  		logger:        utils.DefaultLogger.WithPrefix("client"),
    88  	}
    89  }
    90  
    91  // dial dials the connection
    92  func (c *client) dial() error {
    93  	var err error
    94  
    95  	// [Psiphon]
    96  	var session quic.Session
    97  
    98  	if c.dialer != nil {
    99  		session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
   100  	} else {
   101  		session, err = dialAddr(c.hostname, c.tlsConf, c.config)
   102  	}
   103  	if err != nil {
   104  		return err
   105  	}
   106  
   107  	// [Psiphon]
   108  	// Only this write and the Close reads of c.session require synchronization.
   109  	// After this point, it's safe to concurrently read c.session as it is not
   110  	// rewritten.
   111  	c.sessionMutex.Lock()
   112  	closed := c.closed
   113  	if !closed {
   114  		c.session = session
   115  	}
   116  	c.sessionMutex.Unlock()
   117  	if closed {
   118  		session.Close()
   119  		return errors.New("closed while dialing")
   120  	}
   121  	// [Psiphon]
   122  
   123  	// once the version has been negotiated, open the header stream
   124  	c.headerStream, err = c.session.OpenStream()
   125  	if err != nil {
   126  		return err
   127  	}
   128  	c.requestWriter = newRequestWriter(c.headerStream, c.logger)
   129  	go c.handleHeaderStream()
   130  	return nil
   131  }
   132  
   133  func (c *client) handleHeaderStream() {
   134  	decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
   135  	h2framer := http2.NewFramer(nil, c.headerStream)
   136  
   137  	var err error
   138  	for err == nil {
   139  		err = c.readResponse(h2framer, decoder)
   140  	}
   141  	if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
   142  		c.logger.Debugf("Error handling header stream: %s", err)
   143  	}
   144  	c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
   145  	// stop all running request
   146  	close(c.headerErrored)
   147  }
   148  
   149  func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
   150  	frame, err := h2framer.ReadFrame()
   151  	if err != nil {
   152  		return err
   153  	}
   154  	hframe, ok := frame.(*http2.HeadersFrame)
   155  	if !ok {
   156  		return errors.New("not a headers frame")
   157  	}
   158  	mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
   159  	mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
   160  	if err != nil {
   161  		return fmt.Errorf("cannot read header fields: %s", err.Error())
   162  	}
   163  
   164  	c.mutex.RLock()
   165  	responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
   166  	c.mutex.RUnlock()
   167  	if !ok {
   168  		return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
   169  	}
   170  
   171  	rsp, err := responseFromHeaders(mhframe)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	responseChan <- rsp
   176  	return nil
   177  }
   178  
   179  // Roundtrip executes a request and returns a response
   180  func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
   181  	// TODO: add port to address, if it doesn't have one
   182  	if req.URL.Scheme != "https" {
   183  		return nil, errors.New("quic http2: unsupported scheme")
   184  	}
   185  	if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
   186  		return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
   187  	}
   188  
   189  	c.dialOnce.Do(func() {
   190  		c.handshakeErr = c.dial()
   191  	})
   192  
   193  	if c.handshakeErr != nil {
   194  		return nil, c.handshakeErr
   195  	}
   196  
   197  	hasBody := (req.Body != nil)
   198  
   199  	responseChan := make(chan *http.Response)
   200  	dataStream, err := c.session.OpenStreamSync()
   201  	if err != nil {
   202  		_ = c.closeWithError(err)
   203  		return nil, err
   204  	}
   205  	c.mutex.Lock()
   206  	c.responses[dataStream.StreamID()] = responseChan
   207  	c.mutex.Unlock()
   208  
   209  	var requestedGzip bool
   210  	if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
   211  		requestedGzip = true
   212  	}
   213  	// TODO: add support for trailers
   214  	endStream := !hasBody
   215  	err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
   216  	if err != nil {
   217  		_ = c.closeWithError(err)
   218  		return nil, err
   219  	}
   220  
   221  	resc := make(chan error, 1)
   222  	if hasBody {
   223  		go func() {
   224  			resc <- c.writeRequestBody(dataStream, req.Body)
   225  		}()
   226  	}
   227  
   228  	var res *http.Response
   229  
   230  	var receivedResponse bool
   231  	var bodySent bool
   232  
   233  	if !hasBody {
   234  		bodySent = true
   235  	}
   236  
   237  	ctx := req.Context()
   238  	for !(bodySent && receivedResponse) {
   239  		select {
   240  		case res = <-responseChan:
   241  			receivedResponse = true
   242  			c.mutex.Lock()
   243  			delete(c.responses, dataStream.StreamID())
   244  			c.mutex.Unlock()
   245  		case err := <-resc:
   246  			bodySent = true
   247  			if err != nil {
   248  				return nil, err
   249  			}
   250  		case <-ctx.Done():
   251  			// error code 6 signals that stream was canceled
   252  			dataStream.CancelRead(6)
   253  			dataStream.CancelWrite(6)
   254  			c.mutex.Lock()
   255  			delete(c.responses, dataStream.StreamID())
   256  			c.mutex.Unlock()
   257  			return nil, ctx.Err()
   258  		case <-c.headerErrored:
   259  			// an error occurred on the header stream
   260  			_ = c.closeWithError(c.headerErr)
   261  			return nil, c.headerErr
   262  		}
   263  	}
   264  
   265  	// TODO: correctly set this variable
   266  	var streamEnded bool
   267  	isHead := (req.Method == "HEAD")
   268  
   269  	res = setLength(res, isHead, streamEnded)
   270  
   271  	if streamEnded || isHead {
   272  		res.Body = noBody
   273  	} else {
   274  		res.Body = dataStream
   275  		if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
   276  			res.Header.Del("Content-Encoding")
   277  			res.Header.Del("Content-Length")
   278  			res.ContentLength = -1
   279  			res.Body = &gzipReader{body: res.Body}
   280  			res.Uncompressed = true
   281  		}
   282  	}
   283  
   284  	res.Request = req
   285  	return res, nil
   286  }
   287  
   288  func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
   289  	defer func() {
   290  		cerr := body.Close()
   291  		if err == nil {
   292  			// TODO: what to do with dataStream here? Maybe reset it?
   293  			err = cerr
   294  		}
   295  	}()
   296  
   297  	_, err = io.Copy(dataStream, body)
   298  	if err != nil {
   299  		// TODO: what to do with dataStream here? Maybe reset it?
   300  		return err
   301  	}
   302  	return dataStream.Close()
   303  }
   304  
   305  func (c *client) closeWithError(e error) error {
   306  
   307  	// [Psiphon]
   308  	c.sessionMutex.Lock()
   309  	session := c.session
   310  	c.closed = true
   311  	c.sessionMutex.Unlock()
   312  	// [Psiphon]
   313  
   314  	if session == nil {
   315  		return nil
   316  	}
   317  	return session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
   318  }
   319  
   320  // Close closes the client
   321  func (c *client) Close() error {
   322  
   323  	// [Psiphon]
   324  	c.sessionMutex.Lock()
   325  	session := c.session
   326  	c.closed = true
   327  	c.sessionMutex.Unlock()
   328  	// [Psiphon]
   329  
   330  	if session == nil {
   331  		return nil
   332  	}
   333  	return session.Close()
   334  }
   335  
   336  // copied from net/transport.go
   337  
   338  // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
   339  // and returns a host:port. The port 443 is added if needed.
   340  func authorityAddr(scheme string, authority string) (addr string) {
   341  	host, port, err := net.SplitHostPort(authority)
   342  	if err != nil { // authority didn't have a port
   343  		port = "443"
   344  		if scheme == "http" {
   345  			port = "80"
   346  		}
   347  		host = authority
   348  	}
   349  	if a, err := idna.ToASCII(host); err == nil {
   350  		host = a
   351  	}
   352  	// IPv6 address literal, without a port:
   353  	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
   354  		return host + ":" + port
   355  	}
   356  	return net.JoinHostPort(host, port)
   357  }