github.com/miolini/go@v0.0.0-20160405192216-fca68c8cb408/src/net/http/httptest/recorder_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"testing"
    12  )
    13  
    14  func TestRecorder(t *testing.T) {
    15  	type checkFunc func(*ResponseRecorder) error
    16  	check := func(fns ...checkFunc) []checkFunc { return fns }
    17  
    18  	hasStatus := func(wantCode int) checkFunc {
    19  		return func(rec *ResponseRecorder) error {
    20  			if rec.Code != wantCode {
    21  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    22  			}
    23  			return nil
    24  		}
    25  	}
    26  	hasContents := func(want string) checkFunc {
    27  		return func(rec *ResponseRecorder) error {
    28  			if rec.Body.String() != want {
    29  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    30  			}
    31  			return nil
    32  		}
    33  	}
    34  	hasFlush := func(want bool) checkFunc {
    35  		return func(rec *ResponseRecorder) error {
    36  			if rec.Flushed != want {
    37  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    38  			}
    39  			return nil
    40  		}
    41  	}
    42  	hasHeader := func(key, want string) checkFunc {
    43  		return func(rec *ResponseRecorder) error {
    44  			if got := rec.HeaderMap.Get(key); got != want {
    45  				return fmt.Errorf("header %s = %q; want %q", key, got, want)
    46  			}
    47  			return nil
    48  		}
    49  	}
    50  	hasNotHeaders := func(keys ...string) checkFunc {
    51  		return func(rec *ResponseRecorder) error {
    52  			for _, k := range keys {
    53  				_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
    54  				if ok {
    55  					return fmt.Errorf("unexpected header %s", k)
    56  				}
    57  			}
    58  			return nil
    59  		}
    60  	}
    61  	hasTrailer := func(key, want string) checkFunc {
    62  		return func(rec *ResponseRecorder) error {
    63  			if got := rec.Trailers().Get(key); got != want {
    64  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
    65  			}
    66  			return nil
    67  		}
    68  	}
    69  	hasNotTrailers := func(keys ...string) checkFunc {
    70  		return func(rec *ResponseRecorder) error {
    71  			trailers := rec.Trailers()
    72  			for _, k := range keys {
    73  				_, ok := trailers[http.CanonicalHeaderKey(k)]
    74  				if ok {
    75  					return fmt.Errorf("unexpected trailer %s", k)
    76  				}
    77  			}
    78  			return nil
    79  		}
    80  	}
    81  
    82  	tests := []struct {
    83  		name   string
    84  		h      func(w http.ResponseWriter, r *http.Request)
    85  		checks []checkFunc
    86  	}{
    87  		{
    88  			"200 default",
    89  			func(w http.ResponseWriter, r *http.Request) {},
    90  			check(hasStatus(200), hasContents("")),
    91  		},
    92  		{
    93  			"first code only",
    94  			func(w http.ResponseWriter, r *http.Request) {
    95  				w.WriteHeader(201)
    96  				w.WriteHeader(202)
    97  				w.Write([]byte("hi"))
    98  			},
    99  			check(hasStatus(201), hasContents("hi")),
   100  		},
   101  		{
   102  			"write sends 200",
   103  			func(w http.ResponseWriter, r *http.Request) {
   104  				w.Write([]byte("hi first"))
   105  				w.WriteHeader(201)
   106  				w.WriteHeader(202)
   107  			},
   108  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   109  		},
   110  		{
   111  			"write string",
   112  			func(w http.ResponseWriter, r *http.Request) {
   113  				io.WriteString(w, "hi first")
   114  			},
   115  			check(
   116  				hasStatus(200),
   117  				hasContents("hi first"),
   118  				hasFlush(false),
   119  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   120  			),
   121  		},
   122  		{
   123  			"flush",
   124  			func(w http.ResponseWriter, r *http.Request) {
   125  				w.(http.Flusher).Flush() // also sends a 200
   126  				w.WriteHeader(201)
   127  			},
   128  			check(hasStatus(200), hasFlush(true)),
   129  		},
   130  		{
   131  			"Content-Type detection",
   132  			func(w http.ResponseWriter, r *http.Request) {
   133  				io.WriteString(w, "<html>")
   134  			},
   135  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   136  		},
   137  		{
   138  			"no Content-Type detection with Transfer-Encoding",
   139  			func(w http.ResponseWriter, r *http.Request) {
   140  				w.Header().Set("Transfer-Encoding", "some encoding")
   141  				io.WriteString(w, "<html>")
   142  			},
   143  			check(hasHeader("Content-Type", "")), // no header
   144  		},
   145  		{
   146  			"no Content-Type detection if set explicitly",
   147  			func(w http.ResponseWriter, r *http.Request) {
   148  				w.Header().Set("Content-Type", "some/type")
   149  				io.WriteString(w, "<html>")
   150  			},
   151  			check(hasHeader("Content-Type", "some/type")),
   152  		},
   153  		{
   154  			"Content-Type detection doesn't crash if HeaderMap is nil",
   155  			func(w http.ResponseWriter, r *http.Request) {
   156  				// Act as if the user wrote new(httptest.ResponseRecorder)
   157  				// rather than using NewRecorder (which initializes
   158  				// HeaderMap)
   159  				w.(*ResponseRecorder).HeaderMap = nil
   160  				io.WriteString(w, "<html>")
   161  			},
   162  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   163  		},
   164  		{
   165  			"Header is not changed after write",
   166  			func(w http.ResponseWriter, r *http.Request) {
   167  				hdr := w.Header()
   168  				hdr.Set("Key", "correct")
   169  				w.WriteHeader(200)
   170  				hdr.Set("Key", "incorrect")
   171  			},
   172  			check(hasHeader("Key", "correct")),
   173  		},
   174  		{
   175  			"Trailer headers are correctly recorded",
   176  			func(w http.ResponseWriter, r *http.Request) {
   177  				w.Header().Set("Non-Trailer", "correct")
   178  				w.Header().Set("Trailer", "Trailer-A")
   179  				w.Header().Add("Trailer", "Trailer-B")
   180  				w.Header().Add("Trailer", "Trailer-C")
   181  				io.WriteString(w, "<html>")
   182  				w.Header().Set("Non-Trailer", "incorrect")
   183  				w.Header().Set("Trailer-A", "valuea")
   184  				w.Header().Set("Trailer-C", "valuec")
   185  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   186  			},
   187  			check(
   188  				hasStatus(200),
   189  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   190  				hasHeader("Non-Trailer", "correct"),
   191  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   192  				hasTrailer("Trailer-A", "valuea"),
   193  				hasTrailer("Trailer-C", "valuec"),
   194  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   195  			),
   196  		},
   197  	}
   198  	r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   199  	for _, tt := range tests {
   200  		h := http.HandlerFunc(tt.h)
   201  		rec := NewRecorder()
   202  		h.ServeHTTP(rec, r)
   203  		for _, check := range tt.checks {
   204  			if err := check(rec); err != nil {
   205  				t.Errorf("%s: %v", tt.name, err)
   206  			}
   207  		}
   208  	}
   209  }