github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/http3/response_writer.go (about)

     1  package http3
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"net/http"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/mikelsr/quic-go"
    12  	"github.com/mikelsr/quic-go/internal/utils"
    13  
    14  	"github.com/quic-go/qpack"
    15  )
    16  
    17  type responseWriter struct {
    18  	conn        quic.Connection
    19  	str         quic.Stream
    20  	bufferedStr *bufio.Writer
    21  	buf         []byte
    22  
    23  	header        http.Header
    24  	status        int // status code passed to WriteHeader
    25  	headerWritten bool
    26  
    27  	logger utils.Logger
    28  }
    29  
    30  var (
    31  	_ http.ResponseWriter = &responseWriter{}
    32  	_ http.Flusher        = &responseWriter{}
    33  	_ Hijacker            = &responseWriter{}
    34  )
    35  
    36  func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
    37  	return &responseWriter{
    38  		header:      http.Header{},
    39  		buf:         make([]byte, 16),
    40  		conn:        conn,
    41  		str:         str,
    42  		bufferedStr: bufio.NewWriter(str),
    43  		logger:      logger,
    44  	}
    45  }
    46  
    47  func (w *responseWriter) Header() http.Header {
    48  	return w.header
    49  }
    50  
    51  func (w *responseWriter) WriteHeader(status int) {
    52  	if w.headerWritten {
    53  		return
    54  	}
    55  
    56  	if status < 100 || status >= 200 {
    57  		w.headerWritten = true
    58  	}
    59  	w.status = status
    60  
    61  	var headers bytes.Buffer
    62  	enc := qpack.NewEncoder(&headers)
    63  	enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
    64  
    65  	for k, v := range w.header {
    66  		for index := range v {
    67  			enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
    68  		}
    69  	}
    70  
    71  	w.buf = w.buf[:0]
    72  	w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf)
    73  	w.logger.Infof("Responding with %d", status)
    74  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
    75  		w.logger.Errorf("could not write headers frame: %s", err.Error())
    76  	}
    77  	if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil {
    78  		w.logger.Errorf("could not write header frame payload: %s", err.Error())
    79  	}
    80  	if !w.headerWritten {
    81  		w.Flush()
    82  	}
    83  }
    84  
    85  func (w *responseWriter) Write(p []byte) (int, error) {
    86  	bodyAllowed := bodyAllowedForStatus(w.status)
    87  	if !w.headerWritten {
    88  		// If body is not allowed, we don't need to (and we can't) sniff the content type.
    89  		if bodyAllowed {
    90  			// If no content type, apply sniffing algorithm to body.
    91  			// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
    92  			_, haveType := w.header["Content-Type"]
    93  
    94  			// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
    95  			// we shouldn't sniff the body.
    96  			hasTE := w.header.Get("Transfer-Encoding") != ""
    97  			hasCE := w.header.Get("Content-Encoding") != ""
    98  			if !hasCE && !haveType && !hasTE && len(p) > 0 {
    99  				w.header.Set("Content-Type", http.DetectContentType(p))
   100  			}
   101  		}
   102  		w.WriteHeader(http.StatusOK)
   103  		bodyAllowed = true
   104  	}
   105  	if !bodyAllowed {
   106  		return 0, http.ErrBodyNotAllowed
   107  	}
   108  	df := &dataFrame{Length: uint64(len(p))}
   109  	w.buf = w.buf[:0]
   110  	w.buf = df.Append(w.buf)
   111  	if _, err := w.bufferedStr.Write(w.buf); err != nil {
   112  		return 0, err
   113  	}
   114  	return w.bufferedStr.Write(p)
   115  }
   116  
   117  func (w *responseWriter) Flush() {
   118  	if err := w.bufferedStr.Flush(); err != nil {
   119  		w.logger.Errorf("could not flush to stream: %s", err.Error())
   120  	}
   121  }
   122  
   123  func (w *responseWriter) StreamCreator() StreamCreator {
   124  	return w.conn
   125  }
   126  
   127  func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
   128  	return w.str.SetReadDeadline(deadline)
   129  }
   130  
   131  func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
   132  	return w.str.SetWriteDeadline(deadline)
   133  }
   134  
   135  // copied from http2/http2.go
   136  // bodyAllowedForStatus reports whether a given response status code
   137  // permits a body. See RFC 2616, section 4.4.
   138  func bodyAllowedForStatus(status int) bool {
   139  	switch {
   140  	case status >= 100 && status <= 199:
   141  		return false
   142  	case status == http.StatusNoContent:
   143  		return false
   144  	case status == http.StatusNotModified:
   145  		return false
   146  	}
   147  	return true
   148  }