github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/response_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"log/slog"
     7  	"net/http"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/quic-go/qpack"
    13  )
    14  
// HTTPStreamer allows taking over an HTTP/3 stream. The interface is implemented by the http.Response.Body.
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
	// HTTPStream hands the HTTP/3 stream over to the caller.
	HTTPStream() Stream
}
    21  
// frameHeaderLen is the maximum length of an encoded HTTP/3 frame header.
// A frame header consists of a type and a length field, both of which are
// QUIC varints (at most 8 bytes each), giving 16 bytes in total.
const frameHeaderLen = 16
    25  
// maxSmallResponseSize is the size limit (in bytes) for buffering a response
// body in memory: as long as the header has not been written to the stream and
// the buffered body stays below this limit, Write appends to the buffer instead
// of writing to the stream.
const maxSmallResponseSize = 4096
    27  
// responseWriter writes an HTTP response to an HTTP/3 stream.
// It is used as an http.ResponseWriter, http.Flusher, Hijacker and HTTPStreamer.
type responseWriter struct {
	// str is the stream the response header and body are written to.
	str *stream

	conn   Connection
	header http.Header
	// buf is a scratch buffer reused for serializing DATA frame headers.
	buf    []byte
	status int // status code passed to WriteHeader

	// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
	// and automatically add the Content-Length header
	smallResponseBuf []byte

	contentLen     int64 // if handler set valid Content-Length header
	numWritten     int64 // bytes written
	headerComplete bool  // set once WriteHeader is called with a status code >= 200
	headerWritten  bool  // set once the response header has been serialized to the stream
	isHead         bool

	hijacked bool // set once HTTPStream is called

	// logger may be nil.
	logger *slog.Logger
}
    50  
    51  var (
    52  	_ http.ResponseWriter = &responseWriter{}
    53  	_ http.Flusher        = &responseWriter{}
    54  	_ Hijacker            = &responseWriter{}
    55  	_ HTTPStreamer        = &responseWriter{}
    56  )
    57  
    58  func newResponseWriter(str *stream, conn Connection, isHead bool, logger *slog.Logger) *responseWriter {
    59  	return &responseWriter{
    60  		str:    str,
    61  		conn:   conn,
    62  		header: http.Header{},
    63  		buf:    make([]byte, frameHeaderLen),
    64  		isHead: isHead,
    65  		logger: logger,
    66  	}
    67  }
    68  
// Header returns the response header map. The map itself (not a copy) is
// returned, so changes made by the caller are visible to the responseWriter.
func (w *responseWriter) Header() http.Header {
	return w.header
}
    72  
    73  func (w *responseWriter) WriteHeader(status int) {
    74  	if w.headerComplete {
    75  		return
    76  	}
    77  
    78  	// http status must be 3 digits
    79  	if status < 100 || status > 999 {
    80  		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
    81  	}
    82  	w.status = status
    83  
    84  	// immediately write 1xx headers
    85  	if status < 200 {
    86  		w.writeHeader(status)
    87  		return
    88  	}
    89  
    90  	// We're done with headers once we write a status >= 200.
    91  	w.headerComplete = true
    92  	// Add Date header.
    93  	// This is what the standard library does.
    94  	// Can be disabled by setting the Date header to nil.
    95  	if _, ok := w.header["Date"]; !ok {
    96  		w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
    97  	}
    98  	// Content-Length checking
    99  	// use ParseUint instead of ParseInt, as negative values are invalid
   100  	if clen := w.header.Get("Content-Length"); clen != "" {
   101  		if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
   102  			w.contentLen = int64(cl)
   103  		} else {
   104  			// emit a warning for malformed Content-Length and remove it
   105  			logger := w.logger
   106  			if logger == nil {
   107  				logger = slog.Default()
   108  			}
   109  			logger.Error("Malformed Content-Length", "value", clen)
   110  			w.header.Del("Content-Length")
   111  		}
   112  	}
   113  }
   114  
   115  func (w *responseWriter) sniffContentType(p []byte) {
   116  	// If no content type, apply sniffing algorithm to body.
   117  	// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing.
   118  	_, haveType := w.header["Content-Type"]
   119  
   120  	// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
   121  	// we shouldn't sniff the body.
   122  	hasTE := w.header.Get("Transfer-Encoding") != ""
   123  	hasCE := w.header.Get("Content-Encoding") != ""
   124  	if !hasCE && !haveType && !hasTE && len(p) > 0 {
   125  		w.header.Set("Content-Type", http.DetectContentType(p))
   126  	}
   127  }
   128  
   129  func (w *responseWriter) Write(p []byte) (int, error) {
   130  	bodyAllowed := bodyAllowedForStatus(w.status)
   131  	if !w.headerComplete {
   132  		w.sniffContentType(p)
   133  		w.WriteHeader(http.StatusOK)
   134  		bodyAllowed = true
   135  	}
   136  	if !bodyAllowed {
   137  		return 0, http.ErrBodyNotAllowed
   138  	}
   139  
   140  	w.numWritten += int64(len(p))
   141  	if w.contentLen != 0 && w.numWritten > w.contentLen {
   142  		return 0, http.ErrContentLength
   143  	}
   144  
   145  	if w.isHead {
   146  		return len(p), nil
   147  	}
   148  
   149  	if !w.headerWritten {
   150  		// Buffer small responses.
   151  		// This allows us to automatically set the Content-Length field.
   152  		if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
   153  			w.smallResponseBuf = append(w.smallResponseBuf, p...)
   154  			return len(p), nil
   155  		}
   156  	}
   157  	return w.doWrite(p)
   158  }
   159  
   160  func (w *responseWriter) doWrite(p []byte) (int, error) {
   161  	if !w.headerWritten {
   162  		w.sniffContentType(w.smallResponseBuf)
   163  		if err := w.writeHeader(w.status); err != nil {
   164  			return 0, maybeReplaceError(err)
   165  		}
   166  		w.headerWritten = true
   167  	}
   168  
   169  	l := uint64(len(w.smallResponseBuf) + len(p))
   170  	if l == 0 {
   171  		return 0, nil
   172  	}
   173  	df := &dataFrame{Length: l}
   174  	w.buf = w.buf[:0]
   175  	w.buf = df.Append(w.buf)
   176  	if _, err := w.str.writeUnframed(w.buf); err != nil {
   177  		return 0, maybeReplaceError(err)
   178  	}
   179  	if len(w.smallResponseBuf) > 0 {
   180  		if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
   181  			return 0, maybeReplaceError(err)
   182  		}
   183  		w.smallResponseBuf = nil
   184  	}
   185  	var n int
   186  	if len(p) > 0 {
   187  		var err error
   188  		n, err = w.str.writeUnframed(p)
   189  		if err != nil {
   190  			return n, maybeReplaceError(err)
   191  		}
   192  	}
   193  	return n, nil
   194  }
   195  
   196  func (w *responseWriter) writeHeader(status int) error {
   197  	var headers bytes.Buffer
   198  	enc := qpack.NewEncoder(&headers)
   199  	if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil {
   200  		return err
   201  	}
   202  
   203  	for k, v := range w.header {
   204  		for index := range v {
   205  			if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
   206  				return err
   207  			}
   208  		}
   209  	}
   210  
   211  	buf := make([]byte, 0, frameHeaderLen+headers.Len())
   212  	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
   213  	buf = append(buf, headers.Bytes()...)
   214  
   215  	_, err := w.str.writeUnframed(buf)
   216  	return err
   217  }
   218  
// FlushError writes the response header (if not yet written) and any buffered
// body data to the stream, and returns the error encountered while doing so.
// If the header has not been completed yet, the status defaults to 200 OK.
func (w *responseWriter) FlushError() error {
	if !w.headerComplete {
		w.WriteHeader(http.StatusOK)
	}
	// A nil slice makes doWrite send only the header and the buffered data.
	_, err := w.doWrite(nil)
	return err
}
   226  
   227  func (w *responseWriter) Flush() {
   228  	if err := w.FlushError(); err != nil {
   229  		if w.logger != nil {
   230  			w.logger.Debug("could not flush to stream", "error", err)
   231  		}
   232  	}
   233  }
   234  
// HTTPStream marks the stream as hijacked, flushes the response header and any
// buffered data, and returns the stream. A flush error is not returned to the
// caller; Flush only logs it.
func (w *responseWriter) HTTPStream() Stream {
	w.hijacked = true
	w.Flush()
	return w.str
}
   240  
// wasStreamHijacked reports whether HTTPStream has been called.
func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked }
   242  
// Connection returns the Connection this responseWriter was created with.
func (w *responseWriter) Connection() Connection {
	return w.conn
}
   246  
// SetReadDeadline forwards the read deadline to the underlying stream and
// returns the stream's error.
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
	return w.str.SetReadDeadline(deadline)
}
   250  
// SetWriteDeadline forwards the write deadline to the underlying stream and
// returns the stream's error.
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
	return w.str.SetWriteDeadline(deadline)
}
   254  
   255  // copied from http2/http2.go
   256  // bodyAllowedForStatus reports whether a given response status code
   257  // permits a body. See RFC 2616, section 4.4.
   258  func bodyAllowedForStatus(status int) bool {
   259  	switch {
   260  	case status >= 100 && status <= 199:
   261  		return false
   262  	case status == http.StatusNoContent:
   263  		return false
   264  	case status == http.StatusNotModified:
   265  		return false
   266  	}
   267  	return true
   268  }