github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/response_writer.go (about) 1 package http3 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "net/http" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/danielpfeifer02/quic-go-prio-packs" 13 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 14 15 "github.com/quic-go/qpack" 16 ) 17 18 // The maximum length of an encoded HTTP/3 frame header is 16: 19 // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) 20 const frameHeaderLen = 16 21 22 // headerWriter wraps the stream, so that the first Write call flushes the header to the stream 23 type headerWriter struct { 24 str quic.Stream 25 header http.Header 26 status int // status code passed to WriteHeader 27 written bool 28 29 logger utils.Logger 30 } 31 32 // writeHeader encodes and flush header to the stream 33 func (hw *headerWriter) writeHeader() error { 34 var headers bytes.Buffer 35 enc := qpack.NewEncoder(&headers) 36 enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) 37 38 for k, v := range hw.header { 39 for index := range v { 40 enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) 41 } 42 } 43 44 buf := make([]byte, 0, frameHeaderLen+headers.Len()) 45 buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) 46 hw.logger.Infof("Responding with %d", hw.status) 47 buf = append(buf, headers.Bytes()...) 
48 49 _, err := hw.str.Write(buf) 50 return err 51 } 52 53 // first Write will trigger flushing header 54 func (hw *headerWriter) Write(p []byte) (int, error) { 55 if !hw.written { 56 if err := hw.writeHeader(); err != nil { 57 return 0, err 58 } 59 hw.written = true 60 } 61 return hw.str.Write(p) 62 } 63 64 type responseWriter struct { 65 *headerWriter 66 conn quic.Connection 67 bufferedStr *bufio.Writer 68 buf []byte 69 70 contentLen int64 // if handler set valid Content-Length header 71 numWritten int64 // bytes written 72 headerWritten bool 73 isHead bool 74 } 75 76 var ( 77 _ http.ResponseWriter = &responseWriter{} 78 _ http.Flusher = &responseWriter{} 79 _ Hijacker = &responseWriter{} 80 ) 81 82 func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { 83 hw := &headerWriter{ 84 str: str, 85 header: http.Header{}, 86 logger: logger, 87 } 88 return &responseWriter{ 89 headerWriter: hw, 90 buf: make([]byte, frameHeaderLen), 91 conn: conn, 92 bufferedStr: bufio.NewWriter(hw), 93 } 94 } 95 96 func (w *responseWriter) Header() http.Header { 97 return w.header 98 } 99 100 func (w *responseWriter) WriteHeader(status int) { 101 if w.headerWritten { 102 return 103 } 104 105 // http status must be 3 digits 106 if status < 100 || status > 999 { 107 panic(fmt.Sprintf("invalid WriteHeader code %v", status)) 108 } 109 110 if status >= 200 { 111 w.headerWritten = true 112 // Add Date header. 113 // This is what the standard library does. 114 // Can be disabled by setting the Date header to nil. 
115 if _, ok := w.header["Date"]; !ok { 116 w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) 117 } 118 // Content-Length checking 119 // use ParseUint instead of ParseInt, as negative values are invalid 120 if clen := w.header.Get("Content-Length"); clen != "" { 121 if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { 122 w.contentLen = int64(cl) 123 } else { 124 // emit a warning for malformed Content-Length and remove it 125 w.logger.Errorf("Malformed Content-Length %s", clen) 126 w.header.Del("Content-Length") 127 } 128 } 129 } 130 w.status = status 131 132 if !w.headerWritten { 133 w.writeHeader() 134 } 135 } 136 137 func (w *responseWriter) Write(p []byte) (int, error) { 138 bodyAllowed := bodyAllowedForStatus(w.status) 139 if !w.headerWritten { 140 // If body is not allowed, we don't need to (and we can't) sniff the content type. 141 if bodyAllowed { 142 // If no content type, apply sniffing algorithm to body. 143 // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing. 144 _, haveType := w.header["Content-Type"] 145 146 // If the Transfer-Encoding or Content-Encoding was set and is non-blank, 147 // we shouldn't sniff the body. 
148 hasTE := w.header.Get("Transfer-Encoding") != "" 149 hasCE := w.header.Get("Content-Encoding") != "" 150 if !hasCE && !haveType && !hasTE && len(p) > 0 { 151 w.header.Set("Content-Type", http.DetectContentType(p)) 152 } 153 } 154 w.WriteHeader(http.StatusOK) 155 bodyAllowed = true 156 } 157 if !bodyAllowed { 158 return 0, http.ErrBodyNotAllowed 159 } 160 161 w.numWritten += int64(len(p)) 162 if w.contentLen != 0 && w.numWritten > w.contentLen { 163 return 0, http.ErrContentLength 164 } 165 166 if w.isHead { 167 return len(p), nil 168 } 169 170 df := &dataFrame{Length: uint64(len(p))} 171 w.buf = w.buf[:0] 172 w.buf = df.Append(w.buf) 173 if _, err := w.bufferedStr.Write(w.buf); err != nil { 174 return 0, maybeReplaceError(err) 175 } 176 n, err := w.bufferedStr.Write(p) 177 return n, maybeReplaceError(err) 178 } 179 180 func (w *responseWriter) FlushError() error { 181 if !w.headerWritten { 182 w.WriteHeader(http.StatusOK) 183 } 184 if !w.written { 185 if err := w.writeHeader(); err != nil { 186 return maybeReplaceError(err) 187 } 188 w.written = true 189 } 190 return w.bufferedStr.Flush() 191 } 192 193 func (w *responseWriter) Flush() { 194 if err := w.FlushError(); err != nil { 195 w.logger.Errorf("could not flush to stream: %s", err.Error()) 196 } 197 } 198 199 func (w *responseWriter) StreamCreator() StreamCreator { 200 return w.conn 201 } 202 203 func (w *responseWriter) SetReadDeadline(deadline time.Time) error { 204 return w.str.SetReadDeadline(deadline) 205 } 206 207 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { 208 return w.str.SetWriteDeadline(deadline) 209 } 210 211 // copied from http2/http2.go 212 // bodyAllowedForStatus reports whether a given response status code 213 // permits a body. See RFC 2616, section 4.4. 
func bodyAllowedForStatus(status int) bool {
	// Every informational (1xx) response ends at its header section.
	if status >= 100 && status <= 199 {
		return false
	}
	// 204 No Content and 304 Not Modified are defined to carry no content;
	// every other status may have one.
	switch status {
	case http.StatusNoContent, http.StatusNotModified:
		return false
	default:
		return true
	}
}