github.com/slayercat/go@v0.0.0-20170428012452-c51559813f61/src/net/http/httptest/recorder_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"testing"
    12  )
    13  
    14  func TestRecorder(t *testing.T) {
    15  	type checkFunc func(*ResponseRecorder) error
    16  	check := func(fns ...checkFunc) []checkFunc { return fns }
    17  
    18  	hasStatus := func(wantCode int) checkFunc {
    19  		return func(rec *ResponseRecorder) error {
    20  			if rec.Code != wantCode {
    21  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    22  			}
    23  			return nil
    24  		}
    25  	}
    26  	hasResultStatus := func(wantCode int) checkFunc {
    27  		return func(rec *ResponseRecorder) error {
    28  			if rec.Result().StatusCode != wantCode {
    29  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
    30  			}
    31  			return nil
    32  		}
    33  	}
    34  	hasContents := func(want string) checkFunc {
    35  		return func(rec *ResponseRecorder) error {
    36  			if rec.Body.String() != want {
    37  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    38  			}
    39  			return nil
    40  		}
    41  	}
    42  	hasFlush := func(want bool) checkFunc {
    43  		return func(rec *ResponseRecorder) error {
    44  			if rec.Flushed != want {
    45  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    46  			}
    47  			return nil
    48  		}
    49  	}
    50  	hasOldHeader := func(key, want string) checkFunc {
    51  		return func(rec *ResponseRecorder) error {
    52  			if got := rec.HeaderMap.Get(key); got != want {
    53  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
    54  			}
    55  			return nil
    56  		}
    57  	}
    58  	hasHeader := func(key, want string) checkFunc {
    59  		return func(rec *ResponseRecorder) error {
    60  			if got := rec.Result().Header.Get(key); got != want {
    61  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
    62  			}
    63  			return nil
    64  		}
    65  	}
    66  	hasNotHeaders := func(keys ...string) checkFunc {
    67  		return func(rec *ResponseRecorder) error {
    68  			for _, k := range keys {
    69  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
    70  				if ok {
    71  					return fmt.Errorf("unexpected header %s with value %q", k, v)
    72  				}
    73  			}
    74  			return nil
    75  		}
    76  	}
    77  	hasTrailer := func(key, want string) checkFunc {
    78  		return func(rec *ResponseRecorder) error {
    79  			if got := rec.Result().Trailer.Get(key); got != want {
    80  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
    81  			}
    82  			return nil
    83  		}
    84  	}
    85  	hasNotTrailers := func(keys ...string) checkFunc {
    86  		return func(rec *ResponseRecorder) error {
    87  			trailers := rec.Result().Trailer
    88  			for _, k := range keys {
    89  				_, ok := trailers[http.CanonicalHeaderKey(k)]
    90  				if ok {
    91  					return fmt.Errorf("unexpected trailer %s", k)
    92  				}
    93  			}
    94  			return nil
    95  		}
    96  	}
    97  	hasContentLength := func(length int64) checkFunc {
    98  		return func(rec *ResponseRecorder) error {
    99  			if got := rec.Result().ContentLength; got != length {
   100  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
   101  			}
   102  			return nil
   103  		}
   104  	}
   105  
   106  	tests := []struct {
   107  		name   string
   108  		h      func(w http.ResponseWriter, r *http.Request)
   109  		checks []checkFunc
   110  	}{
   111  		{
   112  			"200 default",
   113  			func(w http.ResponseWriter, r *http.Request) {},
   114  			check(hasStatus(200), hasContents("")),
   115  		},
   116  		{
   117  			"first code only",
   118  			func(w http.ResponseWriter, r *http.Request) {
   119  				w.WriteHeader(201)
   120  				w.WriteHeader(202)
   121  				w.Write([]byte("hi"))
   122  			},
   123  			check(hasStatus(201), hasContents("hi")),
   124  		},
   125  		{
   126  			"write sends 200",
   127  			func(w http.ResponseWriter, r *http.Request) {
   128  				w.Write([]byte("hi first"))
   129  				w.WriteHeader(201)
   130  				w.WriteHeader(202)
   131  			},
   132  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   133  		},
   134  		{
   135  			"write string",
   136  			func(w http.ResponseWriter, r *http.Request) {
   137  				io.WriteString(w, "hi first")
   138  			},
   139  			check(
   140  				hasStatus(200),
   141  				hasContents("hi first"),
   142  				hasFlush(false),
   143  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   144  			),
   145  		},
   146  		{
   147  			"flush",
   148  			func(w http.ResponseWriter, r *http.Request) {
   149  				w.(http.Flusher).Flush() // also sends a 200
   150  				w.WriteHeader(201)
   151  			},
   152  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
   153  		},
   154  		{
   155  			"Content-Type detection",
   156  			func(w http.ResponseWriter, r *http.Request) {
   157  				io.WriteString(w, "<html>")
   158  			},
   159  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   160  		},
   161  		{
   162  			"no Content-Type detection with Transfer-Encoding",
   163  			func(w http.ResponseWriter, r *http.Request) {
   164  				w.Header().Set("Transfer-Encoding", "some encoding")
   165  				io.WriteString(w, "<html>")
   166  			},
   167  			check(hasHeader("Content-Type", "")), // no header
   168  		},
   169  		{
   170  			"no Content-Type detection if set explicitly",
   171  			func(w http.ResponseWriter, r *http.Request) {
   172  				w.Header().Set("Content-Type", "some/type")
   173  				io.WriteString(w, "<html>")
   174  			},
   175  			check(hasHeader("Content-Type", "some/type")),
   176  		},
   177  		{
   178  			"Content-Type detection doesn't crash if HeaderMap is nil",
   179  			func(w http.ResponseWriter, r *http.Request) {
   180  				// Act as if the user wrote new(httptest.ResponseRecorder)
   181  				// rather than using NewRecorder (which initializes
   182  				// HeaderMap)
   183  				w.(*ResponseRecorder).HeaderMap = nil
   184  				io.WriteString(w, "<html>")
   185  			},
   186  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   187  		},
   188  		{
   189  			"Header is not changed after write",
   190  			func(w http.ResponseWriter, r *http.Request) {
   191  				hdr := w.Header()
   192  				hdr.Set("Key", "correct")
   193  				w.WriteHeader(200)
   194  				hdr.Set("Key", "incorrect")
   195  			},
   196  			check(hasHeader("Key", "correct")),
   197  		},
   198  		{
   199  			"Trailer headers are correctly recorded",
   200  			func(w http.ResponseWriter, r *http.Request) {
   201  				w.Header().Set("Non-Trailer", "correct")
   202  				w.Header().Set("Trailer", "Trailer-A")
   203  				w.Header().Add("Trailer", "Trailer-B")
   204  				w.Header().Add("Trailer", "Trailer-C")
   205  				io.WriteString(w, "<html>")
   206  				w.Header().Set("Non-Trailer", "incorrect")
   207  				w.Header().Set("Trailer-A", "valuea")
   208  				w.Header().Set("Trailer-C", "valuec")
   209  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   210  				w.Header().Set("Trailer:Trailer-D", "with prefix")
   211  			},
   212  			check(
   213  				hasStatus(200),
   214  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   215  				hasHeader("Non-Trailer", "correct"),
   216  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   217  				hasTrailer("Trailer-A", "valuea"),
   218  				hasTrailer("Trailer-C", "valuec"),
   219  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   220  				hasTrailer("Trailer-D", "with prefix"),
   221  			),
   222  		},
   223  		{
   224  			"Header set without any write", // Issue 15560
   225  			func(w http.ResponseWriter, r *http.Request) {
   226  				w.Header().Set("X-Foo", "1")
   227  
   228  				// Simulate somebody using
   229  				// new(ResponseRecorder) instead of
   230  				// using the constructor which sets
   231  				// this to 200
   232  				w.(*ResponseRecorder).Code = 0
   233  			},
   234  			check(
   235  				hasOldHeader("X-Foo", "1"),
   236  				hasStatus(0),
   237  				hasHeader("X-Foo", "1"),
   238  				hasResultStatus(200),
   239  			),
   240  		},
   241  		{
   242  			"HeaderMap vs FinalHeaders", // more for Issue 15560
   243  			func(w http.ResponseWriter, r *http.Request) {
   244  				h := w.Header()
   245  				h.Set("X-Foo", "1")
   246  				w.Write([]byte("hi"))
   247  				h.Set("X-Foo", "2")
   248  				h.Set("X-Bar", "2")
   249  			},
   250  			check(
   251  				hasOldHeader("X-Foo", "2"),
   252  				hasOldHeader("X-Bar", "2"),
   253  				hasHeader("X-Foo", "1"),
   254  				hasNotHeaders("X-Bar"),
   255  			),
   256  		},
   257  		{
   258  			"setting Content-Length header",
   259  			func(w http.ResponseWriter, r *http.Request) {
   260  				body := "Some body"
   261  				contentLength := fmt.Sprintf("%d", len(body))
   262  				w.Header().Set("Content-Length", contentLength)
   263  				io.WriteString(w, body)
   264  			},
   265  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
   266  		},
   267  	}
   268  	r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   269  	for _, tt := range tests {
   270  		h := http.HandlerFunc(tt.h)
   271  		rec := NewRecorder()
   272  		h.ServeHTTP(rec, r)
   273  		for _, check := range tt.checks {
   274  			if err := check(rec); err != nil {
   275  				t.Errorf("%s: %v", tt.name, err)
   276  			}
   277  		}
   278  	}
   279  }