github.com/MerlinKodo/quic-go@v0.39.2/http3/response_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"net/http"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/MerlinKodo/quic-go"
    13  	"github.com/MerlinKodo/quic-go/internal/utils"
    14  
    15  	"github.com/quic-go/qpack"
    16  )
    17  
    18  // The maximum length of an encoded HTTP/3 frame header is 16:
    19  // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
    20  const frameHeaderLen = 16
    21  
    22  // headerWriter wraps the stream, so that the first Write call flushes the header to the stream
    23  type headerWriter struct {
    24  	str     quic.Stream
    25  	header  http.Header
    26  	status  int // status code passed to WriteHeader
    27  	written bool
    28  
    29  	logger utils.Logger
    30  }
    31  
    32  // writeHeader encodes and flush header to the stream
    33  func (hw *headerWriter) writeHeader() error {
    34  	var headers bytes.Buffer
    35  	enc := qpack.NewEncoder(&headers)
    36  	enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
    37  
    38  	for k, v := range hw.header {
    39  		for index := range v {
    40  			enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
    41  		}
    42  	}
    43  
    44  	buf := make([]byte, 0, frameHeaderLen+headers.Len())
    45  	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
    46  	hw.logger.Infof("Responding with %d", hw.status)
    47  	buf = append(buf, headers.Bytes()...)
    48  
    49  	_, err := hw.str.Write(buf)
    50  	return err
    51  }
    52  
    53  // first Write will trigger flushing header
    54  func (hw *headerWriter) Write(p []byte) (int, error) {
    55  	if !hw.written {
    56  		if err := hw.writeHeader(); err != nil {
    57  			return 0, err
    58  		}
    59  		hw.written = true
    60  	}
    61  	return hw.str.Write(p)
    62  }
    63  
    64  type responseWriter struct {
    65  	*headerWriter
    66  	conn        quic.Connection
    67  	bufferedStr *bufio.Writer
    68  	buf         []byte
    69  
    70  	headerWritten bool
    71  	contentLen    int64 // if handler set valid Content-Length header
    72  	numWritten    int64 // bytes written
    73  }
    74  
    75  var (
    76  	_ http.ResponseWriter = &responseWriter{}
    77  	_ http.Flusher        = &responseWriter{}
    78  	_ Hijacker            = &responseWriter{}
    79  )
    80  
    81  func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
    82  	hw := &headerWriter{
    83  		str:    str,
    84  		header: http.Header{},
    85  		logger: logger,
    86  	}
    87  	return &responseWriter{
    88  		headerWriter: hw,
    89  		buf:          make([]byte, frameHeaderLen),
    90  		conn:         conn,
    91  		bufferedStr:  bufio.NewWriter(hw),
    92  	}
    93  }
    94  
    95  func (w *responseWriter) Header() http.Header {
    96  	return w.header
    97  }
    98  
    99  func (w *responseWriter) WriteHeader(status int) {
   100  	if w.headerWritten {
   101  		return
   102  	}
   103  
   104  	// http status must be 3 digits
   105  	if status < 100 || status > 999 {
   106  		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
   107  	}
   108  
   109  	if status >= 200 {
   110  		w.headerWritten = true
   111  		// Add Date header.
   112  		// This is what the standard library does.
   113  		// Can be disabled by setting the Date header to nil.
   114  		if _, ok := w.header["Date"]; !ok {
   115  			w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
   116  		}
   117  		// Content-Length checking
   118  		// use ParseUint instead of ParseInt, as negative values are invalid
   119  		if clen := w.header.Get("Content-Length"); clen != "" {
   120  			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
   121  				w.contentLen = int64(cl)
   122  			} else {
   123  				// emit a warning for malformed Content-Length and remove it
   124  				w.logger.Errorf("Malformed Content-Length %s", clen)
   125  				w.header.Del("Content-Length")
   126  			}
   127  		}
   128  	}
   129  	w.status = status
   130  
   131  	if !w.headerWritten {
   132  		w.writeHeader()
   133  	}
   134  }
   135  
   136  func (w *responseWriter) Write(p []byte) (int, error) {
   137  	bodyAllowed := bodyAllowedForStatus(w.status)
   138  	if !w.headerWritten {
   139  		// If body is not allowed, we don't need to (and we can't) sniff the content type.
   140  		if bodyAllowed {
   141  			// If no content type, apply sniffing algorithm to body.
   142  			// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
   143  			_, haveType := w.header["Content-Type"]
   144  
   145  			// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
   146  			// we shouldn't sniff the body.
   147  			hasTE := w.header.Get("Transfer-Encoding") != ""
   148  			hasCE := w.header.Get("Content-Encoding") != ""
   149  			if !hasCE && !haveType && !hasTE && len(p) > 0 {
   150  				w.header.Set("Content-Type", http.DetectContentType(p))
   151  			}
   152  		}
   153  		w.WriteHeader(http.StatusOK)
   154  		bodyAllowed = true
   155  	}
   156  	if !bodyAllowed {
   157  		return 0, http.ErrBodyNotAllowed
   158  	}
   159  
   160  	w.numWritten += int64(len(p))
   161  	if w.contentLen != 0 && w.numWritten > w.contentLen {
   162  		return 0, http.ErrContentLength
   163  	}
   164  
   165  	df := &dataFrame{Length: uint64(len(p))}
   166  	w.buf = w.buf[:0]
   167  	w.buf = df.Append(w.buf)
   168  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
   169  		return 0, maybeReplaceError(err)
   170  	}
   171  	n, err := w.bufferedStr.Write(p)
   172  	return n, maybeReplaceError(err)
   173  }
   174  
   175  func (w *responseWriter) FlushError() error {
   176  	if !w.headerWritten {
   177  		w.WriteHeader(http.StatusOK)
   178  	}
   179  	if !w.written {
   180  		if err := w.writeHeader(); err != nil {
   181  			return maybeReplaceError(err)
   182  		}
   183  		w.written = true
   184  	}
   185  	return w.bufferedStr.Flush()
   186  }
   187  
   188  func (w *responseWriter) Flush() {
   189  	if err := w.FlushError(); err != nil {
   190  		w.logger.Errorf("could not flush to stream: %s", err.Error())
   191  	}
   192  }
   193  
   194  func (w *responseWriter) StreamCreator() StreamCreator {
   195  	return w.conn
   196  }
   197  
   198  func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
   199  	return w.str.SetReadDeadline(deadline)
   200  }
   201  
   202  func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
   203  	return w.str.SetWriteDeadline(deadline)
   204  }
   205  
   206  // copied from http2/http2.go
   207  // bodyAllowedForStatus reports whether a given response status code
   208  // permits a body. See RFC 2616, section 4.4.
   209  func bodyAllowedForStatus(status int) bool {
   210  	switch {
   211  	case status >= 100 && status <= 199:
   212  		return false
   213  	case status == http.StatusNoContent:
   214  		return false
   215  	case status == http.StatusNotModified:
   216  		return false
   217  	}
   218  	return true
   219  }