golang.org/toolchain@v0.0.1-go1.9rc2.windows-amd64/src/net/http/httptest/recorder_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"testing"
    12  )
    13  
    14  func TestRecorder(t *testing.T) {
    15  	type checkFunc func(*ResponseRecorder) error
    16  	check := func(fns ...checkFunc) []checkFunc { return fns }
    17  
    18  	hasStatus := func(wantCode int) checkFunc {
    19  		return func(rec *ResponseRecorder) error {
    20  			if rec.Code != wantCode {
    21  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    22  			}
    23  			return nil
    24  		}
    25  	}
    26  	hasResultStatus := func(want string) checkFunc {
    27  		return func(rec *ResponseRecorder) error {
    28  			if rec.Result().Status != want {
    29  				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
    30  			}
    31  			return nil
    32  		}
    33  	}
    34  	hasResultStatusCode := func(wantCode int) checkFunc {
    35  		return func(rec *ResponseRecorder) error {
    36  			if rec.Result().StatusCode != wantCode {
    37  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
    38  			}
    39  			return nil
    40  		}
    41  	}
    42  	hasContents := func(want string) checkFunc {
    43  		return func(rec *ResponseRecorder) error {
    44  			if rec.Body.String() != want {
    45  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    46  			}
    47  			return nil
    48  		}
    49  	}
    50  	hasFlush := func(want bool) checkFunc {
    51  		return func(rec *ResponseRecorder) error {
    52  			if rec.Flushed != want {
    53  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    54  			}
    55  			return nil
    56  		}
    57  	}
    58  	hasOldHeader := func(key, want string) checkFunc {
    59  		return func(rec *ResponseRecorder) error {
    60  			if got := rec.HeaderMap.Get(key); got != want {
    61  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
    62  			}
    63  			return nil
    64  		}
    65  	}
    66  	hasHeader := func(key, want string) checkFunc {
    67  		return func(rec *ResponseRecorder) error {
    68  			if got := rec.Result().Header.Get(key); got != want {
    69  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
    70  			}
    71  			return nil
    72  		}
    73  	}
    74  	hasNotHeaders := func(keys ...string) checkFunc {
    75  		return func(rec *ResponseRecorder) error {
    76  			for _, k := range keys {
    77  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
    78  				if ok {
    79  					return fmt.Errorf("unexpected header %s with value %q", k, v)
    80  				}
    81  			}
    82  			return nil
    83  		}
    84  	}
    85  	hasTrailer := func(key, want string) checkFunc {
    86  		return func(rec *ResponseRecorder) error {
    87  			if got := rec.Result().Trailer.Get(key); got != want {
    88  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
    89  			}
    90  			return nil
    91  		}
    92  	}
    93  	hasNotTrailers := func(keys ...string) checkFunc {
    94  		return func(rec *ResponseRecorder) error {
    95  			trailers := rec.Result().Trailer
    96  			for _, k := range keys {
    97  				_, ok := trailers[http.CanonicalHeaderKey(k)]
    98  				if ok {
    99  					return fmt.Errorf("unexpected trailer %s", k)
   100  				}
   101  			}
   102  			return nil
   103  		}
   104  	}
   105  	hasContentLength := func(length int64) checkFunc {
   106  		return func(rec *ResponseRecorder) error {
   107  			if got := rec.Result().ContentLength; got != length {
   108  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
   109  			}
   110  			return nil
   111  		}
   112  	}
   113  
   114  	tests := []struct {
   115  		name   string
   116  		h      func(w http.ResponseWriter, r *http.Request)
   117  		checks []checkFunc
   118  	}{
   119  		{
   120  			"200 default",
   121  			func(w http.ResponseWriter, r *http.Request) {},
   122  			check(hasStatus(200), hasContents("")),
   123  		},
   124  		{
   125  			"first code only",
   126  			func(w http.ResponseWriter, r *http.Request) {
   127  				w.WriteHeader(201)
   128  				w.WriteHeader(202)
   129  				w.Write([]byte("hi"))
   130  			},
   131  			check(hasStatus(201), hasContents("hi")),
   132  		},
   133  		{
   134  			"write sends 200",
   135  			func(w http.ResponseWriter, r *http.Request) {
   136  				w.Write([]byte("hi first"))
   137  				w.WriteHeader(201)
   138  				w.WriteHeader(202)
   139  			},
   140  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   141  		},
   142  		{
   143  			"write string",
   144  			func(w http.ResponseWriter, r *http.Request) {
   145  				io.WriteString(w, "hi first")
   146  			},
   147  			check(
   148  				hasStatus(200),
   149  				hasContents("hi first"),
   150  				hasFlush(false),
   151  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   152  			),
   153  		},
   154  		{
   155  			"flush",
   156  			func(w http.ResponseWriter, r *http.Request) {
   157  				w.(http.Flusher).Flush() // also sends a 200
   158  				w.WriteHeader(201)
   159  			},
   160  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
   161  		},
   162  		{
   163  			"Content-Type detection",
   164  			func(w http.ResponseWriter, r *http.Request) {
   165  				io.WriteString(w, "<html>")
   166  			},
   167  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   168  		},
   169  		{
   170  			"no Content-Type detection with Transfer-Encoding",
   171  			func(w http.ResponseWriter, r *http.Request) {
   172  				w.Header().Set("Transfer-Encoding", "some encoding")
   173  				io.WriteString(w, "<html>")
   174  			},
   175  			check(hasHeader("Content-Type", "")), // no header
   176  		},
   177  		{
   178  			"no Content-Type detection if set explicitly",
   179  			func(w http.ResponseWriter, r *http.Request) {
   180  				w.Header().Set("Content-Type", "some/type")
   181  				io.WriteString(w, "<html>")
   182  			},
   183  			check(hasHeader("Content-Type", "some/type")),
   184  		},
   185  		{
   186  			"Content-Type detection doesn't crash if HeaderMap is nil",
   187  			func(w http.ResponseWriter, r *http.Request) {
   188  				// Act as if the user wrote new(httptest.ResponseRecorder)
   189  				// rather than using NewRecorder (which initializes
   190  				// HeaderMap)
   191  				w.(*ResponseRecorder).HeaderMap = nil
   192  				io.WriteString(w, "<html>")
   193  			},
   194  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   195  		},
   196  		{
   197  			"Header is not changed after write",
   198  			func(w http.ResponseWriter, r *http.Request) {
   199  				hdr := w.Header()
   200  				hdr.Set("Key", "correct")
   201  				w.WriteHeader(200)
   202  				hdr.Set("Key", "incorrect")
   203  			},
   204  			check(hasHeader("Key", "correct")),
   205  		},
   206  		{
   207  			"Trailer headers are correctly recorded",
   208  			func(w http.ResponseWriter, r *http.Request) {
   209  				w.Header().Set("Non-Trailer", "correct")
   210  				w.Header().Set("Trailer", "Trailer-A")
   211  				w.Header().Add("Trailer", "Trailer-B")
   212  				w.Header().Add("Trailer", "Trailer-C")
   213  				io.WriteString(w, "<html>")
   214  				w.Header().Set("Non-Trailer", "incorrect")
   215  				w.Header().Set("Trailer-A", "valuea")
   216  				w.Header().Set("Trailer-C", "valuec")
   217  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   218  				w.Header().Set("Trailer:Trailer-D", "with prefix")
   219  			},
   220  			check(
   221  				hasStatus(200),
   222  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   223  				hasHeader("Non-Trailer", "correct"),
   224  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   225  				hasTrailer("Trailer-A", "valuea"),
   226  				hasTrailer("Trailer-C", "valuec"),
   227  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   228  				hasTrailer("Trailer-D", "with prefix"),
   229  			),
   230  		},
   231  		{
   232  			"Header set without any write", // Issue 15560
   233  			func(w http.ResponseWriter, r *http.Request) {
   234  				w.Header().Set("X-Foo", "1")
   235  
   236  				// Simulate somebody using
   237  				// new(ResponseRecorder) instead of
   238  				// using the constructor which sets
   239  				// this to 200
   240  				w.(*ResponseRecorder).Code = 0
   241  			},
   242  			check(
   243  				hasOldHeader("X-Foo", "1"),
   244  				hasStatus(0),
   245  				hasHeader("X-Foo", "1"),
   246  				hasResultStatus("200 OK"),
   247  				hasResultStatusCode(200),
   248  			),
   249  		},
   250  		{
   251  			"HeaderMap vs FinalHeaders", // more for Issue 15560
   252  			func(w http.ResponseWriter, r *http.Request) {
   253  				h := w.Header()
   254  				h.Set("X-Foo", "1")
   255  				w.Write([]byte("hi"))
   256  				h.Set("X-Foo", "2")
   257  				h.Set("X-Bar", "2")
   258  			},
   259  			check(
   260  				hasOldHeader("X-Foo", "2"),
   261  				hasOldHeader("X-Bar", "2"),
   262  				hasHeader("X-Foo", "1"),
   263  				hasNotHeaders("X-Bar"),
   264  			),
   265  		},
   266  		{
   267  			"setting Content-Length header",
   268  			func(w http.ResponseWriter, r *http.Request) {
   269  				body := "Some body"
   270  				contentLength := fmt.Sprintf("%d", len(body))
   271  				w.Header().Set("Content-Length", contentLength)
   272  				io.WriteString(w, body)
   273  			},
   274  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
   275  		},
   276  	}
   277  	r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   278  	for _, tt := range tests {
   279  		h := http.HandlerFunc(tt.h)
   280  		rec := NewRecorder()
   281  		h.ServeHTTP(rec, r)
   282  		for _, check := range tt.checks {
   283  			if err := check(rec); err != nil {
   284  				t.Errorf("%s: %v", tt.name, err)
   285  			}
   286  		}
   287  	}
   288  }