github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/net/http/httptest/recorder_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package httptest 6 7 import ( 8 "fmt" 9 "io" 10 "net/http" 11 "testing" 12 ) 13 14 func TestRecorder(t *testing.T) { 15 type checkFunc func(*ResponseRecorder) error 16 check := func(fns ...checkFunc) []checkFunc { return fns } 17 18 hasStatus := func(wantCode int) checkFunc { 19 return func(rec *ResponseRecorder) error { 20 if rec.Code != wantCode { 21 return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) 22 } 23 return nil 24 } 25 } 26 hasResultStatus := func(want string) checkFunc { 27 return func(rec *ResponseRecorder) error { 28 if rec.Result().Status != want { 29 return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) 30 } 31 return nil 32 } 33 } 34 hasResultStatusCode := func(wantCode int) checkFunc { 35 return func(rec *ResponseRecorder) error { 36 if rec.Result().StatusCode != wantCode { 37 return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) 38 } 39 return nil 40 } 41 } 42 hasResultContents := func(want string) checkFunc { 43 return func(rec *ResponseRecorder) error { 44 contentBytes, err := io.ReadAll(rec.Result().Body) 45 if err != nil { 46 return err 47 } 48 contents := string(contentBytes) 49 if contents != want { 50 return fmt.Errorf("Result().Body = %s; want %s", contents, want) 51 } 52 return nil 53 } 54 } 55 hasContents := func(want string) checkFunc { 56 return func(rec *ResponseRecorder) error { 57 if rec.Body.String() != want { 58 return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) 59 } 60 return nil 61 } 62 } 63 hasFlush := func(want bool) checkFunc { 64 return func(rec *ResponseRecorder) error { 65 if rec.Flushed != want { 66 return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) 67 } 68 return nil 69 } 70 } 71 hasOldHeader := func(key, want string) checkFunc { 72 return func(rec *ResponseRecorder) error { 73 if got := rec.HeaderMap.Get(key); got != want { 74 return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) 75 } 76 return nil 77 } 78 } 79 hasHeader := func(key, want string) checkFunc { 80 return func(rec *ResponseRecorder) error { 81 if got := rec.Result().Header.Get(key); got != want { 82 return fmt.Errorf("final header %s = %q; want %q", key, got, want) 83 } 84 return nil 85 } 86 } 87 hasNotHeaders := func(keys ...string) checkFunc { 88 return func(rec *ResponseRecorder) error { 89 for _, k := range keys { 90 v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] 91 if ok { 92 return fmt.Errorf("unexpected header %s with value %q", k, v) 93 } 94 } 95 return nil 96 } 97 } 98 hasTrailer := func(key, want string) checkFunc { 99 return func(rec *ResponseRecorder) error { 100 if got := rec.Result().Trailer.Get(key); got != want { 101 return fmt.Errorf("trailer %s = %q; want %q", key, got, want) 102 } 103 return nil 104 } 105 } 106 hasNotTrailers := func(keys ...string) checkFunc { 107 return func(rec *ResponseRecorder) error { 108 trailers := rec.Result().Trailer 109 for _, k := range keys { 110 _, ok := trailers[http.CanonicalHeaderKey(k)] 111 if ok { 112 return fmt.Errorf("unexpected trailer %s", k) 113 } 114 } 115 return nil 116 } 117 } 118 hasContentLength := func(length int64) checkFunc { 119 return func(rec *ResponseRecorder) error { 120 if got := rec.Result().ContentLength; got != length { 121 return fmt.Errorf("ContentLength = %d; want %d", got, length) 122 } 123 return nil 124 } 125 } 126 127 for _, tt := range [...]struct { 128 name string 129 h func(w http.ResponseWriter, r *http.Request) 130 checks []checkFunc 131 }{ 132 { 133 "200 default", 134 func(w http.ResponseWriter, r *http.Request) {}, 135 check(hasStatus(200), hasContents("")), 136 }, 137 { 138 "first code only", 139 func(w http.ResponseWriter, r *http.Request) { 140 w.WriteHeader(201) 141 w.WriteHeader(202) 142 w.Write([]byte("hi")) 143 }, 144 check(hasStatus(201), hasContents("hi")), 145 }, 146 { 147 "write sends 200", 148 func(w http.ResponseWriter, r *http.Request) { 149 w.Write([]byte("hi first")) 150 w.WriteHeader(201) 151 w.WriteHeader(202) 152 }, 153 check(hasStatus(200), hasContents("hi first"), hasFlush(false)), 154 }, 155 { 156 "write string", 157 func(w http.ResponseWriter, r *http.Request) { 158 io.WriteString(w, "hi first") 159 }, 160 check( 161 hasStatus(200), 162 hasContents("hi first"), 163 hasFlush(false), 164 hasHeader("Content-Type", "text/plain; charset=utf-8"), 165 ), 166 }, 167 { 168 "flush", 169 func(w http.ResponseWriter, r *http.Request) { 170 w.(http.Flusher).Flush() // also sends a 200 171 w.WriteHeader(201) 172 }, 173 check(hasStatus(200), hasFlush(true), hasContentLength(-1)), 174 }, 175 { 176 "Content-Type detection", 177 func(w http.ResponseWriter, r *http.Request) { 178 io.WriteString(w, "<html>") 179 }, 180 check(hasHeader("Content-Type", "text/html; charset=utf-8")), 181 }, 182 { 183 "no Content-Type detection with Transfer-Encoding", 184 func(w http.ResponseWriter, r *http.Request) { 185 w.Header().Set("Transfer-Encoding", "some encoding") 186 io.WriteString(w, "<html>") 187 }, 188 check(hasHeader("Content-Type", "")), // no header 189 }, 190 { 191 "no Content-Type detection if set explicitly", 192 func(w http.ResponseWriter, r *http.Request) { 193 w.Header().Set("Content-Type", "some/type") 194 io.WriteString(w, "<html>") 195 }, 196 check(hasHeader("Content-Type", "some/type")), 197 }, 198 { 199 "Content-Type detection doesn't crash if HeaderMap is nil", 200 func(w http.ResponseWriter, r *http.Request) { 201 // Act as if the user wrote new(httptest.ResponseRecorder) 202 // rather than using NewRecorder (which initializes 203 // HeaderMap) 204 w.(*ResponseRecorder).HeaderMap = nil 205 io.WriteString(w, "<html>") 206 }, 207 check(hasHeader("Content-Type", "text/html; charset=utf-8")), 208 }, 209 { 210 "Header is not changed after write", 211 func(w http.ResponseWriter, r *http.Request) { 212 hdr := w.Header() 213 hdr.Set("Key", "correct") 214 w.WriteHeader(200) 215 hdr.Set("Key", "incorrect") 216 }, 217 check(hasHeader("Key", "correct")), 218 }, 219 { 220 "Trailer headers are correctly recorded", 221 func(w http.ResponseWriter, r *http.Request) { 222 w.Header().Set("Non-Trailer", "correct") 223 w.Header().Set("Trailer", "Trailer-A") 224 w.Header().Add("Trailer", "Trailer-B") 225 w.Header().Add("Trailer", "Trailer-C") 226 io.WriteString(w, "<html>") 227 w.Header().Set("Non-Trailer", "incorrect") 228 w.Header().Set("Trailer-A", "valuea") 229 w.Header().Set("Trailer-C", "valuec") 230 w.Header().Set("Trailer-NotDeclared", "should be omitted") 231 w.Header().Set("Trailer:Trailer-D", "with prefix") 232 }, 233 check( 234 hasStatus(200), 235 hasHeader("Content-Type", "text/html; charset=utf-8"), 236 hasHeader("Non-Trailer", "correct"), 237 hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), 238 hasTrailer("Trailer-A", "valuea"), 239 hasTrailer("Trailer-C", "valuec"), 240 hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), 241 hasTrailer("Trailer-D", "with prefix"), 242 ), 243 }, 244 { 245 "Header set without any write", // Issue 15560 246 func(w http.ResponseWriter, r *http.Request) { 247 w.Header().Set("X-Foo", "1") 248 249 // Simulate somebody using 250 // new(ResponseRecorder) instead of 251 // using the constructor which sets 252 // this to 200 253 w.(*ResponseRecorder).Code = 0 254 }, 255 check( 256 hasOldHeader("X-Foo", "1"), 257 hasStatus(0), 258 hasHeader("X-Foo", "1"), 259 hasResultStatus("200 OK"), 260 hasResultStatusCode(200), 261 ), 262 }, 263 { 264 "HeaderMap vs FinalHeaders", // more for Issue 15560 265 func(w http.ResponseWriter, r *http.Request) { 266 h := w.Header() 267 h.Set("X-Foo", "1") 268 w.Write([]byte("hi")) 269 h.Set("X-Foo", "2") 270 h.Set("X-Bar", "2") 271 }, 272 check( 273 hasOldHeader("X-Foo", "2"), 274 hasOldHeader("X-Bar", "2"), 275 hasHeader("X-Foo", "1"), 276 hasNotHeaders("X-Bar"), 277 ), 278 }, 279 { 280 "setting Content-Length header", 281 func(w http.ResponseWriter, r *http.Request) { 282 body := "Some body" 283 contentLength := fmt.Sprintf("%d", len(body)) 284 w.Header().Set("Content-Length", contentLength) 285 io.WriteString(w, body) 286 }, 287 check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), 288 }, 289 { 290 "nil ResponseRecorder.Body", // Issue 26642 291 func(w http.ResponseWriter, r *http.Request) { 292 w.(*ResponseRecorder).Body = nil 293 io.WriteString(w, "hi") 294 }, 295 check(hasResultContents("")), // check we don't crash reading the body 296 297 }, 298 } { 299 t.Run(tt.name, func(t *testing.T) { 300 r, _ := http.NewRequest("GET", "http://foo.com/", nil) 301 h := http.HandlerFunc(tt.h) 302 rec := NewRecorder() 303 h.ServeHTTP(rec, r) 304 for _, check := range tt.checks { 305 if err := check(rec); err != nil { 306 t.Error(err) 307 } 308 } 309 }) 310 } 311 } 312 313 // issue 39017 - disallow Content-Length values such as "+3" 314 func TestParseContentLength(t *testing.T) { 315 tests := []struct { 316 cl string 317 want int64 318 }{ 319 { 320 cl: "3", 321 want: 3, 322 }, 323 { 324 cl: "+3", 325 want: -1, 326 }, 327 { 328 cl: "-3", 329 want: -1, 330 }, 331 { 332 // max int64, for safe conversion before returning 333 cl: "9223372036854775807", 334 want: 9223372036854775807, 335 }, 336 { 337 cl: "9223372036854775808", 338 want: -1, 339 }, 340 } 341 342 for _, tt := range tests { 343 if got := parseContentLength(tt.cl); got != tt.want { 344 t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want) 345 } 346 } 347 }