github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/request_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  
    14  	"golang.org/x/net/http/httpguts"
    15  	"golang.org/x/net/http2/hpack"
    16  	"golang.org/x/net/idna"
    17  
    18  	"github.com/danielpfeifer02/quic-go-prio-packs"
    19  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    20  	"github.com/quic-go/qpack"
    21  )
    22  
    23  const bodyCopyBufferSize = 8 * 1024
    24  
    25  type requestWriter struct {
    26  	mutex     sync.Mutex
    27  	encoder   *qpack.Encoder
    28  	headerBuf *bytes.Buffer
    29  
    30  	logger utils.Logger
    31  }
    32  
    33  func newRequestWriter(logger utils.Logger) *requestWriter {
    34  	headerBuf := &bytes.Buffer{}
    35  	encoder := qpack.NewEncoder(headerBuf)
    36  	return &requestWriter{
    37  		encoder:   encoder,
    38  		headerBuf: headerBuf,
    39  		logger:    logger,
    40  	}
    41  }
    42  
    43  func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
    44  	// TODO: figure out how to add support for trailers
    45  	buf := &bytes.Buffer{}
    46  	if err := w.writeHeaders(buf, req, gzip); err != nil {
    47  		return err
    48  	}
    49  	_, err := str.Write(buf.Bytes())
    50  	return err
    51  }
    52  
    53  func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
    54  	w.mutex.Lock()
    55  	defer w.mutex.Unlock()
    56  	defer w.encoder.Close()
    57  	defer w.headerBuf.Reset()
    58  
    59  	if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
    60  		return err
    61  	}
    62  
    63  	b := make([]byte, 0, 128)
    64  	b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
    65  	if _, err := wr.Write(b); err != nil {
    66  		return err
    67  	}
    68  	_, err := wr.Write(w.headerBuf.Bytes())
    69  	return err
    70  }
    71  
    72  // copied from net/transport.go
    73  // Modified to support Extended CONNECT:
    74  // Contrary to what the godoc for the http.Request says,
    75  // we do respect the Proto field if the method is CONNECT.
    76  func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
    77  	host := req.Host
    78  	if host == "" {
    79  		host = req.URL.Host
    80  	}
    81  	host, err := httpguts.PunycodeHostPort(host)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	if !httpguts.ValidHostHeader(host) {
    86  		return errors.New("http3: invalid Host header")
    87  	}
    88  
    89  	// http.NewRequest sets this field to HTTP/1.1
    90  	isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
    91  
    92  	var path string
    93  	if req.Method != http.MethodConnect || isExtendedConnect {
    94  		path = req.URL.RequestURI()
    95  		if !validPseudoPath(path) {
    96  			orig := path
    97  			path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
    98  			if !validPseudoPath(path) {
    99  				if req.URL.Opaque != "" {
   100  					return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
   101  				} else {
   102  					return fmt.Errorf("invalid request :path %q", orig)
   103  				}
   104  			}
   105  		}
   106  	}
   107  
   108  	// Check for any invalid headers and return an error before we
   109  	// potentially pollute our hpack state. (We want to be able to
   110  	// continue to reuse the hpack encoder for future requests)
   111  	for k, vv := range req.Header {
   112  		if !httpguts.ValidHeaderFieldName(k) {
   113  			return fmt.Errorf("invalid HTTP header name %q", k)
   114  		}
   115  		for _, v := range vv {
   116  			if !httpguts.ValidHeaderFieldValue(v) {
   117  				return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
   118  			}
   119  		}
   120  	}
   121  
   122  	enumerateHeaders := func(f func(name, value string)) {
   123  		// 8.1.2.3 Request Pseudo-Header Fields
   124  		// The :path pseudo-header field includes the path and query parts of the
   125  		// target URI (the path-absolute production and optionally a '?' character
   126  		// followed by the query production (see Sections 3.3 and 3.4 of
   127  		// [RFC3986]).
   128  		f(":authority", host)
   129  		f(":method", req.Method)
   130  		if req.Method != http.MethodConnect || isExtendedConnect {
   131  			f(":path", path)
   132  			f(":scheme", req.URL.Scheme)
   133  		}
   134  		if isExtendedConnect {
   135  			f(":protocol", req.Proto)
   136  		}
   137  		if trailers != "" {
   138  			f("trailer", trailers)
   139  		}
   140  
   141  		var didUA bool
   142  		for k, vv := range req.Header {
   143  			if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
   144  				// Host is :authority, already sent.
   145  				// Content-Length is automatic, set below.
   146  				continue
   147  			} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
   148  				strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
   149  				strings.EqualFold(k, "keep-alive") {
   150  				// Per 8.1.2.2 Connection-Specific Header
   151  				// Fields, don't send connection-specific
   152  				// fields. We have already checked if any
   153  				// are error-worthy so just ignore the rest.
   154  				continue
   155  			} else if strings.EqualFold(k, "user-agent") {
   156  				// Match Go's http1 behavior: at most one
   157  				// User-Agent. If set to nil or empty string,
   158  				// then omit it. Otherwise if not mentioned,
   159  				// include the default (below).
   160  				didUA = true
   161  				if len(vv) < 1 {
   162  					continue
   163  				}
   164  				vv = vv[:1]
   165  				if vv[0] == "" {
   166  					continue
   167  				}
   168  
   169  			}
   170  
   171  			for _, v := range vv {
   172  				f(k, v)
   173  			}
   174  		}
   175  		if shouldSendReqContentLength(req.Method, contentLength) {
   176  			f("content-length", strconv.FormatInt(contentLength, 10))
   177  		}
   178  		if addGzipHeader {
   179  			f("accept-encoding", "gzip")
   180  		}
   181  		if !didUA {
   182  			f("user-agent", defaultUserAgent)
   183  		}
   184  	}
   185  
   186  	// Do a first pass over the headers counting bytes to ensure
   187  	// we don't exceed cc.peerMaxHeaderListSize. This is done as a
   188  	// separate pass before encoding the headers to prevent
   189  	// modifying the hpack state.
   190  	hlSize := uint64(0)
   191  	enumerateHeaders(func(name, value string) {
   192  		hf := hpack.HeaderField{Name: name, Value: value}
   193  		hlSize += uint64(hf.Size())
   194  	})
   195  
   196  	// TODO: check maximum header list size
   197  	// if hlSize > cc.peerMaxHeaderListSize {
   198  	// 	return errRequestHeaderListSize
   199  	// }
   200  
   201  	// trace := httptrace.ContextClientTrace(req.Context())
   202  	// traceHeaders := traceHasWroteHeaderField(trace)
   203  
   204  	// Header list size is ok. Write the headers.
   205  	enumerateHeaders(func(name, value string) {
   206  		name = strings.ToLower(name)
   207  		w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
   208  		// if traceHeaders {
   209  		// 	traceWroteHeaderField(trace, name, value)
   210  		// }
   211  	})
   212  
   213  	return nil
   214  }
   215  
   216  // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
   217  // and returns a host:port. The port 443 is added if needed.
   218  func authorityAddr(scheme string, authority string) (addr string) {
   219  	host, port, err := net.SplitHostPort(authority)
   220  	if err != nil { // authority didn't have a port
   221  		port = "443"
   222  		if scheme == "http" {
   223  			port = "80"
   224  		}
   225  		host = authority
   226  	}
   227  	if a, err := idna.ToASCII(host); err == nil {
   228  		host = a
   229  	}
   230  	// IPv6 address literal, without a port:
   231  	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
   232  		return host + ":" + port
   233  	}
   234  	return net.JoinHostPort(host, port)
   235  }
   236  
   237  // validPseudoPath reports whether v is a valid :path pseudo-header
   238  // value. It must be either:
   239  //
   240  //	*) a non-empty string starting with '/'
   241  //	*) the string '*', for OPTIONS requests.
   242  //
   243  // For now this is only used a quick check for deciding when to clean
   244  // up Opaque URLs before sending requests from the Transport.
   245  // See golang.org/issue/16847
   246  //
   247  // We used to enforce that the path also didn't start with "//", but
   248  // Google's GFE accepts such paths and Chrome sends them, so ignore
   249  // that part of the spec. See golang.org/issue/19103.
   250  func validPseudoPath(v string) bool {
   251  	return (len(v) > 0 && v[0] == '/') || v == "*"
   252  }
   253  
   254  // actualContentLength returns a sanitized version of
   255  // req.ContentLength, where 0 actually means zero (not unknown) and -1
   256  // means unknown.
   257  func actualContentLength(req *http.Request) int64 {
   258  	if req.Body == nil {
   259  		return 0
   260  	}
   261  	if req.ContentLength != 0 {
   262  		return req.ContentLength
   263  	}
   264  	return -1
   265  }
   266  
   267  // shouldSendReqContentLength reports whether the http2.Transport should send
   268  // a "content-length" request header. This logic is basically a copy of the net/http
   269  // transferWriter.shouldSendContentLength.
   270  // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
   271  // -1 means unknown.
   272  func shouldSendReqContentLength(method string, contentLength int64) bool {
   273  	if contentLength > 0 {
   274  		return true
   275  	}
   276  	if contentLength < 0 {
   277  		return false
   278  	}
   279  	// For zero bodies, whether we send a content-length depends on the method.
   280  	// It also kinda doesn't matter for http2 either way, with END_STREAM.
   281  	switch method {
   282  	case "POST", "PUT", "PATCH":
   283  		return true
   284  	default:
   285  		return false
   286  	}
   287  }