github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/http3/response_writer.go (about) 1 package http3 2 3 import ( 4 "bufio" 5 "bytes" 6 "net/http" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/mikelsr/quic-go" 12 "github.com/mikelsr/quic-go/internal/utils" 13 14 "github.com/quic-go/qpack" 15 ) 16 17 type responseWriter struct { 18 conn quic.Connection 19 str quic.Stream 20 bufferedStr *bufio.Writer 21 buf []byte 22 23 header http.Header 24 status int // status code passed to WriteHeader 25 headerWritten bool 26 27 logger utils.Logger 28 } 29 30 var ( 31 _ http.ResponseWriter = &responseWriter{} 32 _ http.Flusher = &responseWriter{} 33 _ Hijacker = &responseWriter{} 34 ) 35 36 func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { 37 return &responseWriter{ 38 header: http.Header{}, 39 buf: make([]byte, 16), 40 conn: conn, 41 str: str, 42 bufferedStr: bufio.NewWriter(str), 43 logger: logger, 44 } 45 } 46 47 func (w *responseWriter) Header() http.Header { 48 return w.header 49 } 50 51 func (w *responseWriter) WriteHeader(status int) { 52 if w.headerWritten { 53 return 54 } 55 56 if status < 100 || status >= 200 { 57 w.headerWritten = true 58 } 59 w.status = status 60 61 var headers bytes.Buffer 62 enc := qpack.NewEncoder(&headers) 63 enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) 64 65 for k, v := range w.header { 66 for index := range v { 67 enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) 68 } 69 } 70 71 w.buf = w.buf[:0] 72 w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf) 73 w.logger.Infof("Responding with %d", status) 74 if _, err := w.bufferedStr.Write(w.buf); err != nil { 75 w.logger.Errorf("could not write headers frame: %s", err.Error()) 76 } 77 if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { 78 w.logger.Errorf("could not write header frame payload: %s", err.Error()) 79 } 80 if !w.headerWritten { 81 
w.Flush() 82 } 83 } 84 85 func (w *responseWriter) Write(p []byte) (int, error) { 86 bodyAllowed := bodyAllowedForStatus(w.status) 87 if !w.headerWritten { 88 // If body is not allowed, we don't need to (and we can't) sniff the content type. 89 if bodyAllowed { 90 // If no content type, apply sniffing algorithm to body. 91 // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing. 92 _, haveType := w.header["Content-Type"] 93 94 // If the Transfer-Encoding or Content-Encoding was set and is non-blank, 95 // we shouldn't sniff the body. 96 hasTE := w.header.Get("Transfer-Encoding") != "" 97 hasCE := w.header.Get("Content-Encoding") != "" 98 if !hasCE && !haveType && !hasTE && len(p) > 0 { 99 w.header.Set("Content-Type", http.DetectContentType(p)) 100 } 101 } 102 w.WriteHeader(http.StatusOK) 103 bodyAllowed = true 104 } 105 if !bodyAllowed { 106 return 0, http.ErrBodyNotAllowed 107 } 108 df := &dataFrame{Length: uint64(len(p))} 109 w.buf = w.buf[:0] 110 w.buf = df.Append(w.buf) 111 if _, err := w.bufferedStr.Write(w.buf); err != nil { 112 return 0, err 113 } 114 return w.bufferedStr.Write(p) 115 } 116 117 func (w *responseWriter) Flush() { 118 if err := w.bufferedStr.Flush(); err != nil { 119 w.logger.Errorf("could not flush to stream: %s", err.Error()) 120 } 121 } 122 123 func (w *responseWriter) StreamCreator() StreamCreator { 124 return w.conn 125 } 126 127 func (w *responseWriter) SetReadDeadline(deadline time.Time) error { 128 return w.str.SetReadDeadline(deadline) 129 } 130 131 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { 132 return w.str.SetWriteDeadline(deadline) 133 } 134 135 // copied from http2/http2.go 136 // bodyAllowedForStatus reports whether a given response status code 137 // permits a body. See RFC 2616, section 4.4. 
func bodyAllowedForStatus(status int) bool {
	// Informational responses (the whole 1xx class) never carry content.
	if status >= 100 && status <= 199 {
		return false
	}
	// 204 No Content and 304 Not Modified are also defined without a body;
	// every other code (including unset or out-of-range values) allows one.
	return status != http.StatusNoContent && status != http.StatusNotModified
}