github.com/mattn/go@v0.0.0-20171011075504-07f7db3ea99f/src/net/http/httptest/recorder.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package httptest 6 7 import ( 8 "bytes" 9 "fmt" 10 "io/ioutil" 11 "net/http" 12 "strconv" 13 "strings" 14 ) 15 16 // ResponseRecorder is an implementation of http.ResponseWriter that 17 // records its mutations for later inspection in tests. 18 type ResponseRecorder struct { 19 // Code is the HTTP response code set by WriteHeader. 20 // 21 // Note that if a Handler never calls WriteHeader or Write, 22 // this might end up being 0, rather than the implicit 23 // http.StatusOK. To get the implicit value, use the Result 24 // method. 25 Code int 26 27 // HeaderMap contains the headers explicitly set by the Handler. 28 // 29 // To get the implicit headers set by the server (such as 30 // automatic Content-Type), use the Result method. 31 HeaderMap http.Header 32 33 // Body is the buffer to which the Handler's Write calls are sent. 34 // If nil, the Writes are silently discarded. 35 Body *bytes.Buffer 36 37 // Flushed is whether the Handler called Flush. 38 Flushed bool 39 40 result *http.Response // cache of Result's return value 41 snapHeader http.Header // snapshot of HeaderMap at first Write 42 wroteHeader bool 43 } 44 45 // NewRecorder returns an initialized ResponseRecorder. 46 func NewRecorder() *ResponseRecorder { 47 return &ResponseRecorder{ 48 HeaderMap: make(http.Header), 49 Body: new(bytes.Buffer), 50 Code: 200, 51 } 52 } 53 54 // DefaultRemoteAddr is the default remote address to return in RemoteAddr if 55 // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. 56 const DefaultRemoteAddr = "1.2.3.4" 57 58 // Header returns the response headers. 
59 func (rw *ResponseRecorder) Header() http.Header { 60 m := rw.HeaderMap 61 if m == nil { 62 m = make(http.Header) 63 rw.HeaderMap = m 64 } 65 return m 66 } 67 68 // writeHeader writes a header if it was not written yet and 69 // detects Content-Type if needed. 70 // 71 // bytes or str are the beginning of the response body. 72 // We pass both to avoid unnecessarily generate garbage 73 // in rw.WriteString which was created for performance reasons. 74 // Non-nil bytes win. 75 func (rw *ResponseRecorder) writeHeader(b []byte, str string) { 76 if rw.wroteHeader { 77 return 78 } 79 if len(str) > 512 { 80 str = str[:512] 81 } 82 83 m := rw.Header() 84 85 _, hasType := m["Content-Type"] 86 hasTE := m.Get("Transfer-Encoding") != "" 87 if !hasType && !hasTE { 88 if b == nil { 89 b = []byte(str) 90 } 91 m.Set("Content-Type", http.DetectContentType(b)) 92 } 93 94 rw.WriteHeader(200) 95 } 96 97 // Write always succeeds and writes to rw.Body, if not nil. 98 func (rw *ResponseRecorder) Write(buf []byte) (int, error) { 99 rw.writeHeader(buf, "") 100 if rw.Body != nil { 101 rw.Body.Write(buf) 102 } 103 return len(buf), nil 104 } 105 106 // WriteString always succeeds and writes to rw.Body, if not nil. 107 func (rw *ResponseRecorder) WriteString(str string) (int, error) { 108 rw.writeHeader(nil, str) 109 if rw.Body != nil { 110 rw.Body.WriteString(str) 111 } 112 return len(str), nil 113 } 114 115 // WriteHeader sets rw.Code. After it is called, changing rw.Header 116 // will not affect rw.HeaderMap. 
func (rw *ResponseRecorder) WriteHeader(code int) {
	if rw.wroteHeader {
		return
	}
	if rw.HeaderMap == nil {
		rw.HeaderMap = make(http.Header)
	}
	rw.Code, rw.wroteHeader = code, true
	rw.snapHeader = cloneHeader(rw.HeaderMap)
}

// cloneHeader returns a copy of h whose value slices do not share
// backing arrays with h. A nil h yields an empty, non-nil map.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, values := range h {
		out[key] = append(make([]string, 0, len(values)), values...)
	}
	return out
}

// Flush marks the recorder as flushed, first committing a 200 status if
// no status has been written yet.
func (rw *ResponseRecorder) Flush() {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	rw.Flushed = true
}

// Result returns the response generated by the handler.
//
// The returned Response has at least its StatusCode, Header, and Body
// populated, and optionally Trailer. More fields may be populated in
// the future, so callers should not DeepEqual the result in tests.
//
// Response.Header is a snapshot of the headers taken at the first
// write, or at the time of this call if the handler never wrote.
//
// Response.Body is guaranteed to be non-nil, and calls to Body.Read are
// guaranteed not to return any error other than io.EOF.
//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
	// Later calls return the response built by the first one.
	if rw.result != nil {
		return rw.result
	}
	if rw.snapHeader == nil {
		// WriteHeader never ran, so no snapshot exists yet; take one now.
		rw.snapHeader = cloneHeader(rw.HeaderMap)
	}
	res := &http.Response{
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		StatusCode: rw.Code,
		Header:     rw.snapHeader,
	}
	rw.result = res
	if res.StatusCode == 0 {
		// A zero Code (e.g. a zero-value recorder that never wrote)
		// is reported as the implicit 200.
		res.StatusCode = 200
	}
	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
	if rw.Body != nil {
		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
	} else {
		// Writes were discarded, but the documented contract still
		// promises a non-nil Body whose Read returns only io.EOF.
		res.Body = http.NoBody
	}
	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))

	// Trailers declared up front in the snapshotted "Trailer" header. Their
	// values come from the live HeaderMap rather than the snapshot, so
	// values set after the first write are included.
	if trailers, ok := rw.snapHeader["Trailer"]; ok {
		res.Trailer = make(http.Header, len(trailers))
		for _, decl := range trailers {
			// A single "Trailer" value may name several fields, e.g. "Foo, Bar".
			for _, k := range strings.Split(decl, ",") {
				// Canonicalize before filtering so that lowercase spellings
				// of forbidden names are rejected as well.
				k = http.CanonicalHeaderKey(strings.TrimSpace(k))
				// TODO: use http2.ValidTrailerHeader, but we can't
				// get at it easily because it's bundled into net/http
				// unexported. This is good enough for now:
				switch k {
				case "", "Transfer-Encoding", "Content-Length", "Trailer":
					// Skip empty names and those forbidden by RFC 2616 14.40.
					continue
				}
				vv, ok := rw.HeaderMap[k]
				if !ok {
					continue
				}
				vv2 := make([]string, len(vv))
				copy(vv2, vv)
				res.Trailer[k] = vv2
			}
		}
	}
	// Trailers announced implicitly through keys carrying http.TrailerPrefix.
	for k, vv := range rw.HeaderMap {
		if !strings.HasPrefix(k, http.TrailerPrefix) {
			continue
		}
		if res.Trailer == nil {
			res.Trailer = make(http.Header)
		}
		for _, v := range vv {
			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
		}
	}
	return res
}

// parseContentLength trims whitespace from cl and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This is a modified version of the same function found in
// net/http/transfer.go. This one just ignores an invalid header.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	// ParseUint rejects signed input such as "-5" or "+5", so a negative
	// length is treated as invalid instead of being returned; bitSize 63
	// guarantees the result fits in an int64.
	n, err := strconv.ParseUint(cl, 10, 63)
	if err != nil {
		return -1
	}
	return int64(n)
}