github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/response_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"net/http"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/daeuniverse/quic-go"
    13  	"github.com/daeuniverse/quic-go/internal/utils"
    14  
    15  	"github.com/quic-go/qpack"
    16  )
    17  
    18  // The maximum length of an encoded HTTP/3 frame header is 16:
    19  // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
    20  const frameHeaderLen = 16
    21  
    22  // headerWriter wraps the stream, so that the first Write call flushes the header to the stream
    23  type headerWriter struct {
    24  	str     quic.Stream
    25  	header  http.Header
    26  	status  int // status code passed to WriteHeader
    27  	written bool
    28  
    29  	logger utils.Logger
    30  }
    31  
    32  // writeHeader encodes and flush header to the stream
    33  func (hw *headerWriter) writeHeader() error {
    34  	var headers bytes.Buffer
    35  	enc := qpack.NewEncoder(&headers)
    36  	enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
    37  
    38  	for k, v := range hw.header {
    39  		for index := range v {
    40  			enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
    41  		}
    42  	}
    43  
    44  	buf := make([]byte, 0, frameHeaderLen+headers.Len())
    45  	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
    46  	hw.logger.Infof("Responding with %d", hw.status)
    47  	buf = append(buf, headers.Bytes()...)
    48  
    49  	_, err := hw.str.Write(buf)
    50  	return err
    51  }
    52  
    53  // first Write will trigger flushing header
    54  func (hw *headerWriter) Write(p []byte) (int, error) {
    55  	if !hw.written {
    56  		if err := hw.writeHeader(); err != nil {
    57  			return 0, err
    58  		}
    59  		hw.written = true
    60  	}
    61  	return hw.str.Write(p)
    62  }
    63  
    64  type responseWriter struct {
    65  	*headerWriter
    66  	conn        quic.Connection
    67  	bufferedStr *bufio.Writer
    68  	buf         []byte
    69  
    70  	contentLen    int64 // if handler set valid Content-Length header
    71  	numWritten    int64 // bytes written
    72  	headerWritten bool
    73  	isHead        bool
    74  }
    75  
    76  var (
    77  	_ http.ResponseWriter = &responseWriter{}
    78  	_ http.Flusher        = &responseWriter{}
    79  	_ Hijacker            = &responseWriter{}
    80  )
    81  
    82  func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
    83  	hw := &headerWriter{
    84  		str:    str,
    85  		header: http.Header{},
    86  		logger: logger,
    87  	}
    88  	return &responseWriter{
    89  		headerWriter: hw,
    90  		buf:          make([]byte, frameHeaderLen),
    91  		conn:         conn,
    92  		bufferedStr:  bufio.NewWriter(hw),
    93  	}
    94  }
    95  
    96  func (w *responseWriter) Header() http.Header {
    97  	return w.header
    98  }
    99  
   100  func (w *responseWriter) WriteHeader(status int) {
   101  	if w.headerWritten {
   102  		return
   103  	}
   104  
   105  	// http status must be 3 digits
   106  	if status < 100 || status > 999 {
   107  		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
   108  	}
   109  
   110  	if status >= 200 {
   111  		w.headerWritten = true
   112  		// Add Date header.
   113  		// This is what the standard library does.
   114  		// Can be disabled by setting the Date header to nil.
   115  		if _, ok := w.header["Date"]; !ok {
   116  			w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
   117  		}
   118  		// Content-Length checking
   119  		// use ParseUint instead of ParseInt, as negative values are invalid
   120  		if clen := w.header.Get("Content-Length"); clen != "" {
   121  			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
   122  				w.contentLen = int64(cl)
   123  			} else {
   124  				// emit a warning for malformed Content-Length and remove it
   125  				w.logger.Errorf("Malformed Content-Length %s", clen)
   126  				w.header.Del("Content-Length")
   127  			}
   128  		}
   129  	}
   130  	w.status = status
   131  
   132  	if !w.headerWritten {
   133  		w.writeHeader()
   134  	}
   135  }
   136  
   137  func (w *responseWriter) Write(p []byte) (int, error) {
   138  	bodyAllowed := bodyAllowedForStatus(w.status)
   139  	if !w.headerWritten {
   140  		// If body is not allowed, we don't need to (and we can't) sniff the content type.
   141  		if bodyAllowed {
   142  			// If no content type, apply sniffing algorithm to body.
   143  			// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
   144  			_, haveType := w.header["Content-Type"]
   145  
   146  			// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
   147  			// we shouldn't sniff the body.
   148  			hasTE := w.header.Get("Transfer-Encoding") != ""
   149  			hasCE := w.header.Get("Content-Encoding") != ""
   150  			if !hasCE && !haveType && !hasTE && len(p) > 0 {
   151  				w.header.Set("Content-Type", http.DetectContentType(p))
   152  			}
   153  		}
   154  		w.WriteHeader(http.StatusOK)
   155  		bodyAllowed = true
   156  	}
   157  	if !bodyAllowed {
   158  		return 0, http.ErrBodyNotAllowed
   159  	}
   160  
   161  	w.numWritten += int64(len(p))
   162  	if w.contentLen != 0 && w.numWritten > w.contentLen {
   163  		return 0, http.ErrContentLength
   164  	}
   165  
   166  	if w.isHead {
   167  		return len(p), nil
   168  	}
   169  
   170  	df := &dataFrame{Length: uint64(len(p))}
   171  	w.buf = w.buf[:0]
   172  	w.buf = df.Append(w.buf)
   173  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
   174  		return 0, maybeReplaceError(err)
   175  	}
   176  	n, err := w.bufferedStr.Write(p)
   177  	return n, maybeReplaceError(err)
   178  }
   179  
   180  func (w *responseWriter) FlushError() error {
   181  	if !w.headerWritten {
   182  		w.WriteHeader(http.StatusOK)
   183  	}
   184  	if !w.written {
   185  		if err := w.writeHeader(); err != nil {
   186  			return maybeReplaceError(err)
   187  		}
   188  		w.written = true
   189  	}
   190  	return w.bufferedStr.Flush()
   191  }
   192  
   193  func (w *responseWriter) Flush() {
   194  	if err := w.FlushError(); err != nil {
   195  		w.logger.Errorf("could not flush to stream: %s", err.Error())
   196  	}
   197  }
   198  
   199  func (w *responseWriter) StreamCreator() StreamCreator {
   200  	return w.conn
   201  }
   202  
   203  func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
   204  	return w.str.SetReadDeadline(deadline)
   205  }
   206  
   207  func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
   208  	return w.str.SetWriteDeadline(deadline)
   209  }
   210  
   211  // copied from http2/http2.go
   212  // bodyAllowedForStatus reports whether a given response status code
   213  // permits a body. See RFC 2616, section 4.4.
   214  func bodyAllowedForStatus(status int) bool {
   215  	switch {
   216  	case status >= 100 && status <= 199:
   217  		return false
   218  	case status == http.StatusNoContent:
   219  		return false
   220  	case status == http.StatusNotModified:
   221  		return false
   222  	}
   223  	return true
   224  }