golang.org/toolchain@v0.0.1-go1.9rc2.windows-amd64/src/net/http/httptest/recorder_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package httptest 6 7 import ( 8 "fmt" 9 "io" 10 "net/http" 11 "testing" 12 ) 13 14 func TestRecorder(t *testing.T) { 15 type checkFunc func(*ResponseRecorder) error 16 check := func(fns ...checkFunc) []checkFunc { return fns } 17 18 hasStatus := func(wantCode int) checkFunc { 19 return func(rec *ResponseRecorder) error { 20 if rec.Code != wantCode { 21 return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) 22 } 23 return nil 24 } 25 } 26 hasResultStatus := func(want string) checkFunc { 27 return func(rec *ResponseRecorder) error { 28 if rec.Result().Status != want { 29 return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) 30 } 31 return nil 32 } 33 } 34 hasResultStatusCode := func(wantCode int) checkFunc { 35 return func(rec *ResponseRecorder) error { 36 if rec.Result().StatusCode != wantCode { 37 return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) 38 } 39 return nil 40 } 41 } 42 hasContents := func(want string) checkFunc { 43 return func(rec *ResponseRecorder) error { 44 if rec.Body.String() != want { 45 return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) 46 } 47 return nil 48 } 49 } 50 hasFlush := func(want bool) checkFunc { 51 return func(rec *ResponseRecorder) error { 52 if rec.Flushed != want { 53 return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) 54 } 55 return nil 56 } 57 } 58 hasOldHeader := func(key, want string) checkFunc { 59 return func(rec *ResponseRecorder) error { 60 if got := rec.HeaderMap.Get(key); got != want { 61 return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) 62 } 63 return nil 64 } 65 } 66 hasHeader := func(key, want string) checkFunc { 67 return func(rec *ResponseRecorder) error { 68 if got := rec.Result().Header.Get(key); got != want { 69 return fmt.Errorf("final header %s = %q; want %q", key, got, want) 70 } 71 return nil 72 } 73 } 74 hasNotHeaders := func(keys ...string) checkFunc { 75 return func(rec *ResponseRecorder) error { 76 for _, k := range keys { 77 v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] 78 if ok { 79 return fmt.Errorf("unexpected header %s with value %q", k, v) 80 } 81 } 82 return nil 83 } 84 } 85 hasTrailer := func(key, want string) checkFunc { 86 return func(rec *ResponseRecorder) error { 87 if got := rec.Result().Trailer.Get(key); got != want { 88 return fmt.Errorf("trailer %s = %q; want %q", key, got, want) 89 } 90 return nil 91 } 92 } 93 hasNotTrailers := func(keys ...string) checkFunc { 94 return func(rec *ResponseRecorder) error { 95 trailers := rec.Result().Trailer 96 for _, k := range keys { 97 _, ok := trailers[http.CanonicalHeaderKey(k)] 98 if ok { 99 return fmt.Errorf("unexpected trailer %s", k) 100 } 101 } 102 return nil 103 } 104 } 105 hasContentLength := func(length int64) checkFunc { 106 return func(rec *ResponseRecorder) error { 107 if got := rec.Result().ContentLength; got != length { 108 return fmt.Errorf("ContentLength = %d; want %d", got, length) 109 } 110 return nil 111 } 112 } 113 114 tests := []struct { 115 name string 116 h func(w http.ResponseWriter, r *http.Request) 117 checks []checkFunc 118 }{ 119 { 120 "200 default", 121 func(w http.ResponseWriter, r *http.Request) {}, 122 check(hasStatus(200), hasContents("")), 123 }, 124 { 125 "first code only", 126 func(w http.ResponseWriter, r *http.Request) { 127 w.WriteHeader(201) 128 w.WriteHeader(202) 129 w.Write([]byte("hi")) 130 }, 131 check(hasStatus(201), hasContents("hi")), 132 }, 133 { 134 "write sends 200", 135 func(w http.ResponseWriter, r *http.Request) { 136 w.Write([]byte("hi first")) 137 w.WriteHeader(201) 138 w.WriteHeader(202) 139 }, 140 check(hasStatus(200), hasContents("hi first"), hasFlush(false)), 141 }, 142 { 143 "write string", 144 func(w http.ResponseWriter, r *http.Request) { 145 io.WriteString(w, "hi first") 146 }, 147 check( 148 hasStatus(200), 149 hasContents("hi first"), 150 hasFlush(false), 151 hasHeader("Content-Type", "text/plain; charset=utf-8"), 152 ), 153 }, 154 { 155 "flush", 156 func(w http.ResponseWriter, r *http.Request) { 157 w.(http.Flusher).Flush() // also sends a 200 158 w.WriteHeader(201) 159 }, 160 check(hasStatus(200), hasFlush(true), hasContentLength(-1)), 161 }, 162 { 163 "Content-Type detection", 164 func(w http.ResponseWriter, r *http.Request) { 165 io.WriteString(w, "<html>") 166 }, 167 check(hasHeader("Content-Type", "text/html; charset=utf-8")), 168 }, 169 { 170 "no Content-Type detection with Transfer-Encoding", 171 func(w http.ResponseWriter, r *http.Request) { 172 w.Header().Set("Transfer-Encoding", "some encoding") 173 io.WriteString(w, "<html>") 174 }, 175 check(hasHeader("Content-Type", "")), // no header 176 }, 177 { 178 "no Content-Type detection if set explicitly", 179 func(w http.ResponseWriter, r *http.Request) { 180 w.Header().Set("Content-Type", "some/type") 181 io.WriteString(w, "<html>") 182 }, 183 check(hasHeader("Content-Type", "some/type")), 184 }, 185 { 186 "Content-Type detection doesn't crash if HeaderMap is nil", 187 func(w http.ResponseWriter, r *http.Request) { 188 // Act as if the user wrote new(httptest.ResponseRecorder) 189 // rather than using NewRecorder (which initializes 190 // HeaderMap) 191 w.(*ResponseRecorder).HeaderMap = nil 192 io.WriteString(w, "<html>") 193 }, 194 check(hasHeader("Content-Type", "text/html; charset=utf-8")), 195 }, 196 { 197 "Header is not changed after write", 198 func(w http.ResponseWriter, r *http.Request) { 199 hdr := w.Header() 200 hdr.Set("Key", "correct") 201 w.WriteHeader(200) 202 hdr.Set("Key", "incorrect") 203 }, 204 check(hasHeader("Key", "correct")), 205 }, 206 { 207 "Trailer headers are correctly recorded", 208 func(w http.ResponseWriter, r *http.Request) { 209 w.Header().Set("Non-Trailer", "correct") 210 w.Header().Set("Trailer", "Trailer-A") 211 w.Header().Add("Trailer", "Trailer-B") 212 w.Header().Add("Trailer", "Trailer-C") 213 io.WriteString(w, "<html>") 214 w.Header().Set("Non-Trailer", "incorrect") 215 w.Header().Set("Trailer-A", "valuea") 216 w.Header().Set("Trailer-C", "valuec") 217 w.Header().Set("Trailer-NotDeclared", "should be omitted") 218 w.Header().Set("Trailer:Trailer-D", "with prefix") 219 }, 220 check( 221 hasStatus(200), 222 hasHeader("Content-Type", "text/html; charset=utf-8"), 223 hasHeader("Non-Trailer", "correct"), 224 hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), 225 hasTrailer("Trailer-A", "valuea"), 226 hasTrailer("Trailer-C", "valuec"), 227 hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), 228 hasTrailer("Trailer-D", "with prefix"), 229 ), 230 }, 231 { 232 "Header set without any write", // Issue 15560 233 func(w http.ResponseWriter, r *http.Request) { 234 w.Header().Set("X-Foo", "1") 235 236 // Simulate somebody using 237 // new(ResponseRecorder) instead of 238 // using the constructor which sets 239 // this to 200 240 w.(*ResponseRecorder).Code = 0 241 }, 242 check( 243 hasOldHeader("X-Foo", "1"), 244 hasStatus(0), 245 hasHeader("X-Foo", "1"), 246 hasResultStatus("200 OK"), 247 hasResultStatusCode(200), 248 ), 249 }, 250 { 251 "HeaderMap vs FinalHeaders", // more for Issue 15560 252 func(w http.ResponseWriter, r *http.Request) { 253 h := w.Header() 254 h.Set("X-Foo", "1") 255 w.Write([]byte("hi")) 256 h.Set("X-Foo", "2") 257 h.Set("X-Bar", "2") 258 }, 259 check( 260 hasOldHeader("X-Foo", "2"), 261 hasOldHeader("X-Bar", "2"), 262 hasHeader("X-Foo", "1"), 263 hasNotHeaders("X-Bar"), 264 ), 265 }, 266 { 267 "setting Content-Length header", 268 func(w http.ResponseWriter, r *http.Request) { 269 body := "Some body" 270 contentLength := fmt.Sprintf("%d", len(body)) 271 w.Header().Set("Content-Length", contentLength) 272 io.WriteString(w, body) 273 }, 274 check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), 275 }, 276 } 277 r, _ := http.NewRequest("GET", "http://foo.com/", nil) 278 for _, tt := range tests { 279 h := http.HandlerFunc(tt.h) 280 rec := NewRecorder() 281 h.ServeHTTP(rec, r) 282 for _, check := range tt.checks { 283 if err := check(rec); err != nil { 284 t.Errorf("%s: %v", tt.name, err) 285 } 286 } 287 } 288 }