github.com/MerlinKodo/quic-go@v0.39.2/http3/response_writer.go (about) 1 package http3 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "net/http" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/MerlinKodo/quic-go" 13 "github.com/MerlinKodo/quic-go/internal/utils" 14 15 "github.com/quic-go/qpack" 16 ) 17 18 // The maximum length of an encoded HTTP/3 frame header is 16: 19 // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) 20 const frameHeaderLen = 16 21 22 // headerWriter wraps the stream, so that the first Write call flushes the header to the stream 23 type headerWriter struct { 24 str quic.Stream 25 header http.Header 26 status int // status code passed to WriteHeader 27 written bool 28 29 logger utils.Logger 30 } 31 32 // writeHeader encodes and flush header to the stream 33 func (hw *headerWriter) writeHeader() error { 34 var headers bytes.Buffer 35 enc := qpack.NewEncoder(&headers) 36 enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) 37 38 for k, v := range hw.header { 39 for index := range v { 40 enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) 41 } 42 } 43 44 buf := make([]byte, 0, frameHeaderLen+headers.Len()) 45 buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) 46 hw.logger.Infof("Responding with %d", hw.status) 47 buf = append(buf, headers.Bytes()...) 48 49 _, err := hw.str.Write(buf) 50 return err 51 } 52 53 // first Write will trigger flushing header 54 func (hw *headerWriter) Write(p []byte) (int, error) { 55 if !hw.written { 56 if err := hw.writeHeader(); err != nil { 57 return 0, err 58 } 59 hw.written = true 60 } 61 return hw.str.Write(p) 62 } 63 64 type responseWriter struct { 65 *headerWriter 66 conn quic.Connection 67 bufferedStr *bufio.Writer 68 buf []byte 69 70 headerWritten bool 71 contentLen int64 // if handler set valid Content-Length header 72 numWritten int64 // bytes written 73 } 74 75 var ( 76 _ http.ResponseWriter = &responseWriter{} 77 _ http.Flusher = &responseWriter{} 78 _ Hijacker = &responseWriter{} 79 ) 80 81 func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { 82 hw := &headerWriter{ 83 str: str, 84 header: http.Header{}, 85 logger: logger, 86 } 87 return &responseWriter{ 88 headerWriter: hw, 89 buf: make([]byte, frameHeaderLen), 90 conn: conn, 91 bufferedStr: bufio.NewWriter(hw), 92 } 93 } 94 95 func (w *responseWriter) Header() http.Header { 96 return w.header 97 } 98 99 func (w *responseWriter) WriteHeader(status int) { 100 if w.headerWritten { 101 return 102 } 103 104 // http status must be 3 digits 105 if status < 100 || status > 999 { 106 panic(fmt.Sprintf("invalid WriteHeader code %v", status)) 107 } 108 109 if status >= 200 { 110 w.headerWritten = true 111 // Add Date header. 112 // This is what the standard library does. 113 // Can be disabled by setting the Date header to nil. 114 if _, ok := w.header["Date"]; !ok { 115 w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) 116 } 117 // Content-Length checking 118 // use ParseUint instead of ParseInt, as negative values are invalid 119 if clen := w.header.Get("Content-Length"); clen != "" { 120 if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { 121 w.contentLen = int64(cl) 122 } else { 123 // emit a warning for malformed Content-Length and remove it 124 w.logger.Errorf("Malformed Content-Length %s", clen) 125 w.header.Del("Content-Length") 126 } 127 } 128 } 129 w.status = status 130 131 if !w.headerWritten { 132 w.writeHeader() 133 } 134 } 135 136 func (w *responseWriter) Write(p []byte) (int, error) { 137 bodyAllowed := bodyAllowedForStatus(w.status) 138 if !w.headerWritten { 139 // If body is not allowed, we don't need to (and we can't) sniff the content type. 140 if bodyAllowed { 141 // If no content type, apply sniffing algorithm to body. 142 // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing. 143 _, haveType := w.header["Content-Type"] 144 145 // If the Transfer-Encoding or Content-Encoding was set and is non-blank, 146 // we shouldn't sniff the body. 147 hasTE := w.header.Get("Transfer-Encoding") != "" 148 hasCE := w.header.Get("Content-Encoding") != "" 149 if !hasCE && !haveType && !hasTE && len(p) > 0 { 150 w.header.Set("Content-Type", http.DetectContentType(p)) 151 } 152 } 153 w.WriteHeader(http.StatusOK) 154 bodyAllowed = true 155 } 156 if !bodyAllowed { 157 return 0, http.ErrBodyNotAllowed 158 } 159 160 w.numWritten += int64(len(p)) 161 if w.contentLen != 0 && w.numWritten > w.contentLen { 162 return 0, http.ErrContentLength 163 } 164 165 df := &dataFrame{Length: uint64(len(p))} 166 w.buf = w.buf[:0] 167 w.buf = df.Append(w.buf) 168 if _, err := w.bufferedStr.Write(w.buf); err != nil { 169 return 0, maybeReplaceError(err) 170 } 171 n, err := w.bufferedStr.Write(p) 172 return n, maybeReplaceError(err) 173 } 174 175 func (w *responseWriter) FlushError() error { 176 if !w.headerWritten { 177 w.WriteHeader(http.StatusOK) 178 } 179 if !w.written { 180 if err := w.writeHeader(); err != nil { 181 return maybeReplaceError(err) 182 } 183 w.written = true 184 } 185 return w.bufferedStr.Flush() 186 } 187 188 func (w *responseWriter) Flush() { 189 if err := w.FlushError(); err != nil { 190 w.logger.Errorf("could not flush to stream: %s", err.Error()) 191 } 192 } 193 194 func (w *responseWriter) StreamCreator() StreamCreator { 195 return w.conn 196 } 197 198 func (w *responseWriter) SetReadDeadline(deadline time.Time) error { 199 return w.str.SetReadDeadline(deadline) 200 } 201 202 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { 203 return w.str.SetWriteDeadline(deadline) 204 } 205 206 // copied from http2/http2.go 207 // bodyAllowedForStatus reports whether a given response status code 208 // permits a body. See RFC 2616, section 4.4. 209 func bodyAllowedForStatus(status int) bool { 210 switch { 211 case status >= 100 && status <= 199: 212 return false 213 case status == http.StatusNoContent: 214 return false 215 case status == http.StatusNotModified: 216 return false 217 } 218 return true 219 }