github.com/tumi8/quic-go@v0.37.4-tum/http3/response_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"net/http"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/tumi8/quic-go"
    13  	"github.com/tumi8/quic-go/noninternal/utils"
    14  	"github.com/quic-go/qpack"
    15  
    16  )
    17  
    18  type responseWriter struct {
    19  	conn        quic.Connection
    20  	str         quic.Stream
    21  	bufferedStr *bufio.Writer
    22  	buf         []byte
    23  
    24  	header        http.Header
    25  	status        int // status code passed to WriteHeader
    26  	headerWritten bool
    27  	contentLen    int64 // if handler set valid Content-Length header
    28  	numWritten    int64 // bytes written
    29  
    30  	logger utils.Logger
    31  }
    32  
    33  var (
    34  	_ http.ResponseWriter = &responseWriter{}
    35  	_ http.Flusher        = &responseWriter{}
    36  	_ Hijacker            = &responseWriter{}
    37  )
    38  
    39  func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
    40  	return &responseWriter{
    41  		header:      http.Header{},
    42  		buf:         make([]byte, 16),
    43  		conn:        conn,
    44  		str:         str,
    45  		bufferedStr: bufio.NewWriter(str),
    46  		logger:      logger,
    47  	}
    48  }
    49  
    50  func (w *responseWriter) Header() http.Header {
    51  	return w.header
    52  }
    53  
    54  func (w *responseWriter) WriteHeader(status int) {
    55  	if w.headerWritten {
    56  		return
    57  	}
    58  
    59  	// http status must be 3 digits
    60  	if status < 100 || status > 999 {
    61  		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
    62  	}
    63  
    64  	if status >= 200 {
    65  		w.headerWritten = true
    66  		// Add Date header.
    67  		// This is what the standard library does.
    68  		// Can be disabled by setting the Date header to nil.
    69  		if _, ok := w.header["Date"]; !ok {
    70  			w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
    71  		}
    72  		// Content-Length checking
    73  		// use ParseUint instead of ParseInt, as negative values are invalid
    74  		if clen := w.header.Get("Content-Length"); clen != "" {
    75  			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
    76  				w.contentLen = int64(cl)
    77  			} else {
    78  				// emit a warning for malformed Content-Length and remove it
    79  				w.logger.Errorf("Malformed Content-Length %s", clen)
    80  				w.header.Del("Content-Length")
    81  			}
    82  		}
    83  	}
    84  	w.status = status
    85  
    86  	var headers bytes.Buffer
    87  	enc := qpack.NewEncoder(&headers)
    88  	enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
    89  
    90  	for k, v := range w.header {
    91  		for index := range v {
    92  			enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
    93  		}
    94  	}
    95  
    96  	w.buf = w.buf[:0]
    97  	w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf)
    98  	w.logger.Infof("Responding with %d", status)
    99  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
   100  		w.logger.Errorf("could not write headers frame: %s", err.Error())
   101  	}
   102  	if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil {
   103  		w.logger.Errorf("could not write header frame payload: %s", err.Error())
   104  	}
   105  	if !w.headerWritten {
   106  		w.Flush()
   107  	}
   108  }
   109  
   110  func (w *responseWriter) Write(p []byte) (int, error) {
   111  	bodyAllowed := bodyAllowedForStatus(w.status)
   112  	if !w.headerWritten {
   113  		// If body is not allowed, we don't need to (and we can't) sniff the content type.
   114  		if bodyAllowed {
   115  			// If no content type, apply sniffing algorithm to body.
   116  			// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
   117  			_, haveType := w.header["Content-Type"]
   118  
   119  			// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
   120  			// we shouldn't sniff the body.
   121  			hasTE := w.header.Get("Transfer-Encoding") != ""
   122  			hasCE := w.header.Get("Content-Encoding") != ""
   123  			if !hasCE && !haveType && !hasTE && len(p) > 0 {
   124  				w.header.Set("Content-Type", http.DetectContentType(p))
   125  			}
   126  		}
   127  		w.WriteHeader(http.StatusOK)
   128  		bodyAllowed = true
   129  	}
   130  	if !bodyAllowed {
   131  		return 0, http.ErrBodyNotAllowed
   132  	}
   133  
   134  	w.numWritten += int64(len(p))
   135  	if w.contentLen != 0 && w.numWritten > w.contentLen {
   136  		return 0, http.ErrContentLength
   137  	}
   138  
   139  	df := &dataFrame{Length: uint64(len(p))}
   140  	w.buf = w.buf[:0]
   141  	w.buf = df.Append(w.buf)
   142  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
   143  		return 0, err
   144  	}
   145  	return w.bufferedStr.Write(p)
   146  }
   147  
   148  func (w *responseWriter) FlushError() error {
   149  	return w.bufferedStr.Flush()
   150  }
   151  
   152  func (w *responseWriter) Flush() {
   153  	if err := w.FlushError(); err != nil {
   154  		w.logger.Errorf("could not flush to stream: %s", err.Error())
   155  	}
   156  }
   157  
   158  func (w *responseWriter) StreamCreator() StreamCreator {
   159  	return w.conn
   160  }
   161  
   162  func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
   163  	return w.str.SetReadDeadline(deadline)
   164  }
   165  
   166  func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
   167  	return w.str.SetWriteDeadline(deadline)
   168  }
   169  
   170  // copied from http2/http2.go
   171  // bodyAllowedForStatus reports whether a given response status code
   172  // permits a body. See RFC 2616, section 4.4.
   173  func bodyAllowedForStatus(status int) bool {
   174  	switch {
   175  	case status >= 100 && status <= 199:
   176  		return false
   177  	case status == http.StatusNoContent:
   178  		return false
   179  	case status == http.StatusNotModified:
   180  		return false
   181  	}
   182  	return true
   183  }