github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/http3/request_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go"
    14  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    15  	"github.com/marten-seemann/qpack"
    16  	"golang.org/x/net/http/httpguts"
    17  	"golang.org/x/net/http2/hpack"
    18  	"golang.org/x/net/idna"
    19  )
    20  
// bodyCopyBufferSize is the size, in bytes, of the buffer WriteRequest uses
// to read the request body. Each non-empty read is sent as one DATA frame, so
// this is also the largest DATA frame payload the writer produces.
const bodyCopyBufferSize = 8 * 1024
    22  
// requestWriter serializes HTTP requests onto QUIC streams. It keeps a single
// QPACK encoder that is shared by all requests written through it.
type requestWriter struct {
	// mutex is held by writeHeaders for as long as encoder and headerBuf
	// are in use for one request.
	mutex     sync.Mutex
	// encoder emits encoded header fields into headerBuf.
	encoder   *qpack.Encoder
	// headerBuf collects the encoded header block until writeHeaders has
	// copied it to its destination.
	headerBuf *bytes.Buffer

	// logger receives errors that occur while the request body is sent
	// in the background.
	logger utils.Logger
}
    30  
    31  func newRequestWriter(logger utils.Logger) *requestWriter {
    32  	headerBuf := &bytes.Buffer{}
    33  	encoder := qpack.NewEncoder(headerBuf)
    34  	return &requestWriter{
    35  		encoder:   encoder,
    36  		headerBuf: headerBuf,
    37  		logger:    logger,
    38  	}
    39  }
    40  
    41  func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
    42  	buf := &bytes.Buffer{}
    43  	if err := w.writeHeaders(buf, req, gzip); err != nil {
    44  		return err
    45  	}
    46  	if _, err := str.Write(buf.Bytes()); err != nil {
    47  		return err
    48  	}
    49  	// TODO: add support for trailers
    50  	if req.Body == nil {
    51  		str.Close()
    52  		return nil
    53  	}
    54  
    55  	// send the request body asynchronously
    56  	go func() {
    57  		defer req.Body.Close()
    58  		b := make([]byte, bodyCopyBufferSize)
    59  		for {
    60  			n, rerr := req.Body.Read(b)
    61  			if n == 0 {
    62  				if rerr == nil {
    63  					continue
    64  				} else if rerr == io.EOF {
    65  					break
    66  				}
    67  			}
    68  			buf := &bytes.Buffer{}
    69  			(&dataFrame{Length: uint64(n)}).Write(buf)
    70  			if _, err := str.Write(buf.Bytes()); err != nil {
    71  				w.logger.Errorf("Error writing request: %s", err)
    72  				return
    73  			}
    74  			if _, err := str.Write(b[:n]); err != nil {
    75  				w.logger.Errorf("Error writing request: %s", err)
    76  				return
    77  			}
    78  			if rerr != nil {
    79  				if rerr == io.EOF {
    80  					break
    81  				}
    82  				str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
    83  				w.logger.Errorf("Error writing request: %s", rerr)
    84  				return
    85  			}
    86  		}
    87  		str.Close()
    88  	}()
    89  
    90  	return nil
    91  }
    92  
    93  func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
    94  	w.mutex.Lock()
    95  	defer w.mutex.Unlock()
    96  	defer w.encoder.Close()
    97  
    98  	if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
    99  		return err
   100  	}
   101  
   102  	buf := &bytes.Buffer{}
   103  	hf := headersFrame{Length: uint64(w.headerBuf.Len())}
   104  	hf.Write(buf)
   105  	if _, err := wr.Write(buf.Bytes()); err != nil {
   106  		return err
   107  	}
   108  	if _, err := wr.Write(w.headerBuf.Bytes()); err != nil {
   109  		return err
   110  	}
   111  	w.headerBuf.Reset()
   112  	return nil
   113  }
   114  
   115  // copied from net/transport.go
   116  
   117  func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
   118  	host := req.Host
   119  	if host == "" {
   120  		host = req.URL.Host
   121  	}
   122  	host, err := httpguts.PunycodeHostPort(host)
   123  	if err != nil {
   124  		return err
   125  	}
   126  
   127  	var path string
   128  	if req.Method != "CONNECT" {
   129  		path = req.URL.RequestURI()
   130  		if !validPseudoPath(path) {
   131  			orig := path
   132  			path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
   133  			if !validPseudoPath(path) {
   134  				if req.URL.Opaque != "" {
   135  					return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
   136  				} else {
   137  					return fmt.Errorf("invalid request :path %q", orig)
   138  				}
   139  			}
   140  		}
   141  	}
   142  
   143  	// Check for any invalid headers and return an error before we
   144  	// potentially pollute our hpack state. (We want to be able to
   145  	// continue to reuse the hpack encoder for future requests)
   146  	for k, vv := range req.Header {
   147  		if !httpguts.ValidHeaderFieldName(k) {
   148  			return fmt.Errorf("invalid HTTP header name %q", k)
   149  		}
   150  		for _, v := range vv {
   151  			if !httpguts.ValidHeaderFieldValue(v) {
   152  				return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
   153  			}
   154  		}
   155  	}
   156  
   157  	enumerateHeaders := func(f func(name, value string)) {
   158  		// 8.1.2.3 Request Pseudo-Header Fields
   159  		// The :path pseudo-header field includes the path and query parts of the
   160  		// target URI (the path-absolute production and optionally a '?' character
   161  		// followed by the query production (see Sections 3.3 and 3.4 of
   162  		// [RFC3986]).
   163  		f(":authority", host)
   164  		f(":method", req.Method)
   165  		if req.Method != "CONNECT" {
   166  			f(":path", path)
   167  			f(":scheme", req.URL.Scheme)
   168  		}
   169  		if trailers != "" {
   170  			f("trailer", trailers)
   171  		}
   172  
   173  		var didUA bool
   174  		for k, vv := range req.Header {
   175  			if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
   176  				// Host is :authority, already sent.
   177  				// Content-Length is automatic, set below.
   178  				continue
   179  			} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
   180  				strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
   181  				strings.EqualFold(k, "keep-alive") {
   182  				// Per 8.1.2.2 Connection-Specific Header
   183  				// Fields, don't send connection-specific
   184  				// fields. We have already checked if any
   185  				// are error-worthy so just ignore the rest.
   186  				continue
   187  			} else if strings.EqualFold(k, "user-agent") {
   188  				// Match Go's http1 behavior: at most one
   189  				// User-Agent. If set to nil or empty string,
   190  				// then omit it. Otherwise if not mentioned,
   191  				// include the default (below).
   192  				didUA = true
   193  				if len(vv) < 1 {
   194  					continue
   195  				}
   196  				vv = vv[:1]
   197  				if vv[0] == "" {
   198  					continue
   199  				}
   200  
   201  			}
   202  
   203  			for _, v := range vv {
   204  				f(k, v)
   205  			}
   206  		}
   207  		if shouldSendReqContentLength(req.Method, contentLength) {
   208  			f("content-length", strconv.FormatInt(contentLength, 10))
   209  		}
   210  		if addGzipHeader {
   211  			f("accept-encoding", "gzip")
   212  		}
   213  		if !didUA {
   214  			f("user-agent", defaultUserAgent)
   215  		}
   216  	}
   217  
   218  	// Do a first pass over the headers counting bytes to ensure
   219  	// we don't exceed cc.peerMaxHeaderListSize. This is done as a
   220  	// separate pass before encoding the headers to prevent
   221  	// modifying the hpack state.
   222  	hlSize := uint64(0)
   223  	enumerateHeaders(func(name, value string) {
   224  		hf := hpack.HeaderField{Name: name, Value: value}
   225  		hlSize += uint64(hf.Size())
   226  	})
   227  
   228  	// TODO: check maximum header list size
   229  	// if hlSize > cc.peerMaxHeaderListSize {
   230  	// 	return errRequestHeaderListSize
   231  	// }
   232  
   233  	// trace := httptrace.ContextClientTrace(req.Context())
   234  	// traceHeaders := traceHasWroteHeaderField(trace)
   235  
   236  	// Header list size is ok. Write the headers.
   237  	enumerateHeaders(func(name, value string) {
   238  		name = strings.ToLower(name)
   239  		w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
   240  		// if traceHeaders {
   241  		// 	traceWroteHeaderField(trace, name, value)
   242  		// }
   243  	})
   244  
   245  	return nil
   246  }
   247  
   248  // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
   249  // and returns a host:port. The port 443 is added if needed.
   250  func authorityAddr(scheme string, authority string) (addr string) {
   251  	host, port, err := net.SplitHostPort(authority)
   252  	if err != nil { // authority didn't have a port
   253  		port = "443"
   254  		if scheme == "http" {
   255  			port = "80"
   256  		}
   257  		host = authority
   258  	}
   259  	if a, err := idna.ToASCII(host); err == nil {
   260  		host = a
   261  	}
   262  	// IPv6 address literal, without a port:
   263  	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
   264  		return host + ":" + port
   265  	}
   266  	return net.JoinHostPort(host, port)
   267  }
   268  
   269  // validPseudoPath reports whether v is a valid :path pseudo-header
   270  // value. It must be either:
   271  //
   272  //     *) a non-empty string starting with '/'
   273  //     *) the string '*', for OPTIONS requests.
   274  //
   275  // For now this is only used a quick check for deciding when to clean
   276  // up Opaque URLs before sending requests from the Transport.
   277  // See golang.org/issue/16847
   278  //
   279  // We used to enforce that the path also didn't start with "//", but
   280  // Google's GFE accepts such paths and Chrome sends them, so ignore
   281  // that part of the spec. See golang.org/issue/19103.
   282  func validPseudoPath(v string) bool {
   283  	return (len(v) > 0 && v[0] == '/') || v == "*"
   284  }
   285  
   286  // actualContentLength returns a sanitized version of
   287  // req.ContentLength, where 0 actually means zero (not unknown) and -1
   288  // means unknown.
   289  func actualContentLength(req *http.Request) int64 {
   290  	if req.Body == nil {
   291  		return 0
   292  	}
   293  	if req.ContentLength != 0 {
   294  		return req.ContentLength
   295  	}
   296  	return -1
   297  }
   298  
   299  // shouldSendReqContentLength reports whether the http2.Transport should send
   300  // a "content-length" request header. This logic is basically a copy of the net/http
   301  // transferWriter.shouldSendContentLength.
   302  // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
   303  // -1 means unknown.
   304  func shouldSendReqContentLength(method string, contentLength int64) bool {
   305  	if contentLength > 0 {
   306  		return true
   307  	}
   308  	if contentLength < 0 {
   309  		return false
   310  	}
   311  	// For zero bodies, whether we send a content-length depends on the method.
   312  	// It also kinda doesn't matter for http2 either way, with END_STREAM.
   313  	switch method {
   314  	case "POST", "PUT", "PATCH":
   315  		return true
   316  	default:
   317  		return false
   318  	}
   319  }