github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/response_writer.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/quic-go/qpack" 13 ) 14 15 // The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body. 16 // On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. 17 // When a stream is taken over, it's the caller's responsibility to close the stream. 18 type HTTPStreamer interface { 19 HTTPStream() Stream 20 } 21 22 // The maximum length of an encoded HTTP/3 frame header is 16: 23 // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) 24 const frameHeaderLen = 16 25 26 const maxSmallResponseSize = 4096 27 28 type responseWriter struct { 29 str *stream 30 31 conn Connection 32 header http.Header 33 buf []byte 34 status int // status code passed to WriteHeader 35 36 // for responses smaller than maxSmallResponseSize, we buffer calls to Write, 37 // and automatically add the Content-Length header 38 smallResponseBuf []byte 39 40 contentLen int64 // if handler set valid Content-Length header 41 numWritten int64 // bytes written 42 headerComplete bool // set once WriteHeader is called with a status code >= 200 43 headerWritten bool // set once the response header has been serialized to the stream 44 isHead bool 45 46 hijacked bool // set on HTTPStream is called 47 48 logger *slog.Logger 49 } 50 51 var ( 52 _ http.ResponseWriter = &responseWriter{} 53 _ http.Flusher = &responseWriter{} 54 _ Hijacker = &responseWriter{} 55 _ HTTPStreamer = &responseWriter{} 56 ) 57 58 func newResponseWriter(str *stream, conn Connection, isHead bool, logger *slog.Logger) *responseWriter { 59 return &responseWriter{ 60 str: str, 61 conn: conn, 62 header: http.Header{}, 63 buf: make([]byte, frameHeaderLen), 64 isHead: isHead, 65 logger: logger, 66 } 67 } 68 69 func (w *responseWriter) Header() http.Header { 70 return w.header 71 } 72 73 func (w *responseWriter) WriteHeader(status int) { 74 if w.headerComplete { 75 return 76 } 77 78 // http status must be 3 digits 79 if status < 100 || status > 999 { 80 panic(fmt.Sprintf("invalid WriteHeader code %v", status)) 81 } 82 w.status = status 83 84 // immediately write 1xx headers 85 if status < 200 { 86 w.writeHeader(status) 87 return 88 } 89 90 // We're done with headers once we write a status >= 200. 91 w.headerComplete = true 92 // Add Date header. 93 // This is what the standard library does. 94 // Can be disabled by setting the Date header to nil. 95 if _, ok := w.header["Date"]; !ok { 96 w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) 97 } 98 // Content-Length checking 99 // use ParseUint instead of ParseInt, as negative values are invalid 100 if clen := w.header.Get("Content-Length"); clen != "" { 101 if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { 102 w.contentLen = int64(cl) 103 } else { 104 // emit a warning for malformed Content-Length and remove it 105 logger := w.logger 106 if logger == nil { 107 logger = slog.Default() 108 } 109 logger.Error("Malformed Content-Length", "value", clen) 110 w.header.Del("Content-Length") 111 } 112 } 113 } 114 115 func (w *responseWriter) sniffContentType(p []byte) { 116 // If no content type, apply sniffing algorithm to body. 117 // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing. 118 _, haveType := w.header["Content-Type"] 119 120 // If the Transfer-Encoding or Content-Encoding was set and is non-blank, 121 // we shouldn't sniff the body. 122 hasTE := w.header.Get("Transfer-Encoding") != "" 123 hasCE := w.header.Get("Content-Encoding") != "" 124 if !hasCE && !haveType && !hasTE && len(p) > 0 { 125 w.header.Set("Content-Type", http.DetectContentType(p)) 126 } 127 } 128 129 func (w *responseWriter) Write(p []byte) (int, error) { 130 bodyAllowed := bodyAllowedForStatus(w.status) 131 if !w.headerComplete { 132 w.sniffContentType(p) 133 w.WriteHeader(http.StatusOK) 134 bodyAllowed = true 135 } 136 if !bodyAllowed { 137 return 0, http.ErrBodyNotAllowed 138 } 139 140 w.numWritten += int64(len(p)) 141 if w.contentLen != 0 && w.numWritten > w.contentLen { 142 return 0, http.ErrContentLength 143 } 144 145 if w.isHead { 146 return len(p), nil 147 } 148 149 if !w.headerWritten { 150 // Buffer small responses. 151 // This allows us to automatically set the Content-Length field. 152 if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize { 153 w.smallResponseBuf = append(w.smallResponseBuf, p...) 154 return len(p), nil 155 } 156 } 157 return w.doWrite(p) 158 } 159 160 func (w *responseWriter) doWrite(p []byte) (int, error) { 161 if !w.headerWritten { 162 w.sniffContentType(w.smallResponseBuf) 163 if err := w.writeHeader(w.status); err != nil { 164 return 0, maybeReplaceError(err) 165 } 166 w.headerWritten = true 167 } 168 169 l := uint64(len(w.smallResponseBuf) + len(p)) 170 if l == 0 { 171 return 0, nil 172 } 173 df := &dataFrame{Length: l} 174 w.buf = w.buf[:0] 175 w.buf = df.Append(w.buf) 176 if _, err := w.str.writeUnframed(w.buf); err != nil { 177 return 0, maybeReplaceError(err) 178 } 179 if len(w.smallResponseBuf) > 0 { 180 if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil { 181 return 0, maybeReplaceError(err) 182 } 183 w.smallResponseBuf = nil 184 } 185 var n int 186 if len(p) > 0 { 187 var err error 188 n, err = w.str.writeUnframed(p) 189 if err != nil { 190 return n, maybeReplaceError(err) 191 } 192 } 193 return n, nil 194 } 195 196 func (w *responseWriter) writeHeader(status int) error { 197 var headers bytes.Buffer 198 enc := qpack.NewEncoder(&headers) 199 if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil { 200 return err 201 } 202 203 for k, v := range w.header { 204 for index := range v { 205 if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil { 206 return err 207 } 208 } 209 } 210 211 buf := make([]byte, 0, frameHeaderLen+headers.Len()) 212 buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) 213 buf = append(buf, headers.Bytes()...) 214 215 _, err := w.str.writeUnframed(buf) 216 return err 217 } 218 219 func (w *responseWriter) FlushError() error { 220 if !w.headerComplete { 221 w.WriteHeader(http.StatusOK) 222 } 223 _, err := w.doWrite(nil) 224 return err 225 } 226 227 func (w *responseWriter) Flush() { 228 if err := w.FlushError(); err != nil { 229 if w.logger != nil { 230 w.logger.Debug("could not flush to stream", "error", err) 231 } 232 } 233 } 234 235 func (w *responseWriter) HTTPStream() Stream { 236 w.hijacked = true 237 w.Flush() 238 return w.str 239 } 240 241 func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked } 242 243 func (w *responseWriter) Connection() Connection { 244 return w.conn 245 } 246 247 func (w *responseWriter) SetReadDeadline(deadline time.Time) error { 248 return w.str.SetReadDeadline(deadline) 249 } 250 251 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { 252 return w.str.SetWriteDeadline(deadline) 253 } 254 255 // copied from http2/http2.go 256 // bodyAllowedForStatus reports whether a given response status code 257 // permits a body. See RFC 2616, section 4.4. 258 func bodyAllowedForStatus(status int) bool { 259 switch { 260 case status >= 100 && status <= 199: 261 return false 262 case status == http.StatusNoContent: 263 return false 264 case status == http.StatusNotModified: 265 return false 266 } 267 return true 268 }