github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmhttp/httptest/recorder_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"testing"
    11  
    12  	http "github.com/hxx258456/ccgo/gmhttp"
    13  )
    14  
    15  func TestRecorder(t *testing.T) {
    16  	type checkFunc func(*ResponseRecorder) error
    17  	check := func(fns ...checkFunc) []checkFunc { return fns }
    18  
    19  	hasStatus := func(wantCode int) checkFunc {
    20  		return func(rec *ResponseRecorder) error {
    21  			if rec.Code != wantCode {
    22  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    23  			}
    24  			return nil
    25  		}
    26  	}
    27  	hasResultStatus := func(want string) checkFunc {
    28  		return func(rec *ResponseRecorder) error {
    29  			if rec.Result().Status != want {
    30  				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
    31  			}
    32  			return nil
    33  		}
    34  	}
    35  	hasResultStatusCode := func(wantCode int) checkFunc {
    36  		return func(rec *ResponseRecorder) error {
    37  			if rec.Result().StatusCode != wantCode {
    38  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
    39  			}
    40  			return nil
    41  		}
    42  	}
    43  	hasResultContents := func(want string) checkFunc {
    44  		return func(rec *ResponseRecorder) error {
    45  			contentBytes, err := io.ReadAll(rec.Result().Body)
    46  			if err != nil {
    47  				return err
    48  			}
    49  			contents := string(contentBytes)
    50  			if contents != want {
    51  				return fmt.Errorf("Result().Body = %s; want %s", contents, want)
    52  			}
    53  			return nil
    54  		}
    55  	}
    56  	hasContents := func(want string) checkFunc {
    57  		return func(rec *ResponseRecorder) error {
    58  			if rec.Body.String() != want {
    59  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    60  			}
    61  			return nil
    62  		}
    63  	}
    64  	hasFlush := func(want bool) checkFunc {
    65  		return func(rec *ResponseRecorder) error {
    66  			if rec.Flushed != want {
    67  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    68  			}
    69  			return nil
    70  		}
    71  	}
    72  	hasOldHeader := func(key, want string) checkFunc {
    73  		return func(rec *ResponseRecorder) error {
    74  			if got := rec.HeaderMap.Get(key); got != want {
    75  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
    76  			}
    77  			return nil
    78  		}
    79  	}
    80  	hasHeader := func(key, want string) checkFunc {
    81  		return func(rec *ResponseRecorder) error {
    82  			if got := rec.Result().Header.Get(key); got != want {
    83  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
    84  			}
    85  			return nil
    86  		}
    87  	}
    88  	hasNotHeaders := func(keys ...string) checkFunc {
    89  		return func(rec *ResponseRecorder) error {
    90  			for _, k := range keys {
    91  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
    92  				if ok {
    93  					return fmt.Errorf("unexpected header %s with value %q", k, v)
    94  				}
    95  			}
    96  			return nil
    97  		}
    98  	}
    99  	hasTrailer := func(key, want string) checkFunc {
   100  		return func(rec *ResponseRecorder) error {
   101  			if got := rec.Result().Trailer.Get(key); got != want {
   102  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
   103  			}
   104  			return nil
   105  		}
   106  	}
   107  	hasNotTrailers := func(keys ...string) checkFunc {
   108  		return func(rec *ResponseRecorder) error {
   109  			trailers := rec.Result().Trailer
   110  			for _, k := range keys {
   111  				_, ok := trailers[http.CanonicalHeaderKey(k)]
   112  				if ok {
   113  					return fmt.Errorf("unexpected trailer %s", k)
   114  				}
   115  			}
   116  			return nil
   117  		}
   118  	}
   119  	hasContentLength := func(length int64) checkFunc {
   120  		return func(rec *ResponseRecorder) error {
   121  			if got := rec.Result().ContentLength; got != length {
   122  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
   123  			}
   124  			return nil
   125  		}
   126  	}
   127  
   128  	for _, tt := range [...]struct {
   129  		name   string
   130  		h      func(w http.ResponseWriter, r *http.Request)
   131  		checks []checkFunc
   132  	}{
   133  		{
   134  			"200 default",
   135  			func(w http.ResponseWriter, r *http.Request) {},
   136  			check(hasStatus(200), hasContents("")),
   137  		},
   138  		{
   139  			"first code only",
   140  			func(w http.ResponseWriter, r *http.Request) {
   141  				w.WriteHeader(201)
   142  				w.WriteHeader(202)
   143  				w.Write([]byte("hi"))
   144  			},
   145  			check(hasStatus(201), hasContents("hi")),
   146  		},
   147  		{
   148  			"write sends 200",
   149  			func(w http.ResponseWriter, r *http.Request) {
   150  				w.Write([]byte("hi first"))
   151  				w.WriteHeader(201)
   152  				w.WriteHeader(202)
   153  			},
   154  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   155  		},
   156  		{
   157  			"write string",
   158  			func(w http.ResponseWriter, r *http.Request) {
   159  				io.WriteString(w, "hi first")
   160  			},
   161  			check(
   162  				hasStatus(200),
   163  				hasContents("hi first"),
   164  				hasFlush(false),
   165  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   166  			),
   167  		},
   168  		{
   169  			"flush",
   170  			func(w http.ResponseWriter, r *http.Request) {
   171  				w.(http.Flusher).Flush() // also sends a 200
   172  				w.WriteHeader(201)
   173  			},
   174  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
   175  		},
   176  		{
   177  			"Content-Type detection",
   178  			func(w http.ResponseWriter, r *http.Request) {
   179  				io.WriteString(w, "<html>")
   180  			},
   181  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   182  		},
   183  		{
   184  			"no Content-Type detection with Transfer-Encoding",
   185  			func(w http.ResponseWriter, r *http.Request) {
   186  				w.Header().Set("Transfer-Encoding", "some encoding")
   187  				io.WriteString(w, "<html>")
   188  			},
   189  			check(hasHeader("Content-Type", "")), // no header
   190  		},
   191  		{
   192  			"no Content-Type detection if set explicitly",
   193  			func(w http.ResponseWriter, r *http.Request) {
   194  				w.Header().Set("Content-Type", "some/type")
   195  				io.WriteString(w, "<html>")
   196  			},
   197  			check(hasHeader("Content-Type", "some/type")),
   198  		},
   199  		{
   200  			"Content-Type detection doesn't crash if HeaderMap is nil",
   201  			func(w http.ResponseWriter, r *http.Request) {
   202  				// Act as if the user wrote new(httptest.ResponseRecorder)
   203  				// rather than using NewRecorder (which initializes
   204  				// HeaderMap)
   205  				w.(*ResponseRecorder).HeaderMap = nil
   206  				io.WriteString(w, "<html>")
   207  			},
   208  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   209  		},
   210  		{
   211  			"Header is not changed after write",
   212  			func(w http.ResponseWriter, r *http.Request) {
   213  				hdr := w.Header()
   214  				hdr.Set("Key", "correct")
   215  				w.WriteHeader(200)
   216  				hdr.Set("Key", "incorrect")
   217  			},
   218  			check(hasHeader("Key", "correct")),
   219  		},
   220  		{
   221  			"Trailer headers are correctly recorded",
   222  			func(w http.ResponseWriter, r *http.Request) {
   223  				w.Header().Set("Non-Trailer", "correct")
   224  				w.Header().Set("Trailer", "Trailer-A")
   225  				w.Header().Add("Trailer", "Trailer-B")
   226  				w.Header().Add("Trailer", "Trailer-C")
   227  				io.WriteString(w, "<html>")
   228  				w.Header().Set("Non-Trailer", "incorrect")
   229  				w.Header().Set("Trailer-A", "valuea")
   230  				w.Header().Set("Trailer-C", "valuec")
   231  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   232  				w.Header().Set("Trailer:Trailer-D", "with prefix")
   233  			},
   234  			check(
   235  				hasStatus(200),
   236  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   237  				hasHeader("Non-Trailer", "correct"),
   238  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   239  				hasTrailer("Trailer-A", "valuea"),
   240  				hasTrailer("Trailer-C", "valuec"),
   241  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   242  				hasTrailer("Trailer-D", "with prefix"),
   243  			),
   244  		},
   245  		{
   246  			"Header set without any write", // Issue 15560
   247  			func(w http.ResponseWriter, r *http.Request) {
   248  				w.Header().Set("X-Foo", "1")
   249  
   250  				// Simulate somebody using
   251  				// new(ResponseRecorder) instead of
   252  				// using the constructor which sets
   253  				// this to 200
   254  				w.(*ResponseRecorder).Code = 0
   255  			},
   256  			check(
   257  				hasOldHeader("X-Foo", "1"),
   258  				hasStatus(0),
   259  				hasHeader("X-Foo", "1"),
   260  				hasResultStatus("200 OK"),
   261  				hasResultStatusCode(200),
   262  			),
   263  		},
   264  		{
   265  			"HeaderMap vs FinalHeaders", // more for Issue 15560
   266  			func(w http.ResponseWriter, r *http.Request) {
   267  				h := w.Header()
   268  				h.Set("X-Foo", "1")
   269  				w.Write([]byte("hi"))
   270  				h.Set("X-Foo", "2")
   271  				h.Set("X-Bar", "2")
   272  			},
   273  			check(
   274  				hasOldHeader("X-Foo", "2"),
   275  				hasOldHeader("X-Bar", "2"),
   276  				hasHeader("X-Foo", "1"),
   277  				hasNotHeaders("X-Bar"),
   278  			),
   279  		},
   280  		{
   281  			"setting Content-Length header",
   282  			func(w http.ResponseWriter, r *http.Request) {
   283  				body := "Some body"
   284  				contentLength := fmt.Sprintf("%d", len(body))
   285  				w.Header().Set("Content-Length", contentLength)
   286  				io.WriteString(w, body)
   287  			},
   288  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
   289  		},
   290  		{
   291  			"nil ResponseRecorder.Body", // Issue 26642
   292  			func(w http.ResponseWriter, r *http.Request) {
   293  				w.(*ResponseRecorder).Body = nil
   294  				io.WriteString(w, "hi")
   295  			},
   296  			check(hasResultContents("")), // check we don't crash reading the body
   297  
   298  		},
   299  	} {
   300  		t.Run(tt.name, func(t *testing.T) {
   301  			r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   302  			h := http.HandlerFunc(tt.h)
   303  			rec := NewRecorder()
   304  			h.ServeHTTP(rec, r)
   305  			for _, check := range tt.checks {
   306  				if err := check(rec); err != nil {
   307  					t.Error(err)
   308  				}
   309  			}
   310  		})
   311  	}
   312  }
   313  
   314  // issue 39017 - disallow Content-Length values such as "+3"
   315  func TestParseContentLength(t *testing.T) {
   316  	tests := []struct {
   317  		cl   string
   318  		want int64
   319  	}{
   320  		{
   321  			cl:   "3",
   322  			want: 3,
   323  		},
   324  		{
   325  			cl:   "+3",
   326  			want: -1,
   327  		},
   328  		{
   329  			cl:   "-3",
   330  			want: -1,
   331  		},
   332  		{
   333  			// max int64, for safe conversion before returning
   334  			cl:   "9223372036854775807",
   335  			want: 9223372036854775807,
   336  		},
   337  		{
   338  			cl:   "9223372036854775808",
   339  			want: -1,
   340  		},
   341  	}
   342  
   343  	for _, tt := range tests {
   344  		if got := parseContentLength(tt.cl); got != tt.want {
   345  			t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
   346  		}
   347  	}
   348  }
   349  
   350  // Ensure that httptest.Recorder panics when given a non-3 digit (XXX)
   351  // status HTTP code. See https://golang.org/issues/45353
   352  func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
   353  	badCodes := []int{
   354  		-100, 0, 99, 1000, 20000,
   355  	}
   356  	for _, badCode := range badCodes {
   357  		badCode := badCode
   358  		t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
   359  			defer func() {
   360  				if r := recover(); r == nil {
   361  					t.Fatal("Expected a panic")
   362  				}
   363  			}()
   364  
   365  			handler := func(rw http.ResponseWriter, _ *http.Request) {
   366  				rw.WriteHeader(badCode)
   367  			}
   368  			r, _ := http.NewRequest("GET", "http://example.org/", nil)
   369  			rw := NewRecorder()
   370  			handler(rw, r)
   371  		})
   372  	}
   373  }