gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmhttp/httptest/recorder_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"testing"
    11  
    12  	http "gitee.com/ks-custle/core-gm/gmhttp"
    13  )
    14  
    15  func TestRecorder(t *testing.T) {
    16  	type checkFunc func(*ResponseRecorder) error
    17  	check := func(fns ...checkFunc) []checkFunc { return fns }
    18  
    19  	hasStatus := func(wantCode int) checkFunc {
    20  		return func(rec *ResponseRecorder) error {
    21  			if rec.Code != wantCode {
    22  				//goland:noinspection GoErrorStringFormat
    23  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    24  			}
    25  			return nil
    26  		}
    27  	}
    28  	hasResultStatus := func(want string) checkFunc {
    29  		return func(rec *ResponseRecorder) error {
    30  			if rec.Result().Status != want {
    31  				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
    32  			}
    33  			return nil
    34  		}
    35  	}
    36  	hasResultStatusCode := func(wantCode int) checkFunc {
    37  		return func(rec *ResponseRecorder) error {
    38  			if rec.Result().StatusCode != wantCode {
    39  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
    40  			}
    41  			return nil
    42  		}
    43  	}
    44  	hasResultContents := func(want string) checkFunc {
    45  		return func(rec *ResponseRecorder) error {
    46  			contentBytes, err := io.ReadAll(rec.Result().Body)
    47  			if err != nil {
    48  				return err
    49  			}
    50  			contents := string(contentBytes)
    51  			if contents != want {
    52  				return fmt.Errorf("Result().Body = %s; want %s", contents, want)
    53  			}
    54  			return nil
    55  		}
    56  	}
    57  	hasContents := func(want string) checkFunc {
    58  		return func(rec *ResponseRecorder) error {
    59  			if rec.Body.String() != want {
    60  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    61  			}
    62  			return nil
    63  		}
    64  	}
    65  	hasFlush := func(want bool) checkFunc {
    66  		return func(rec *ResponseRecorder) error {
    67  			if rec.Flushed != want {
    68  				//goland:noinspection GoErrorStringFormat
    69  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    70  			}
    71  			return nil
    72  		}
    73  	}
    74  	hasOldHeader := func(key, want string) checkFunc {
    75  		return func(rec *ResponseRecorder) error {
    76  			//goland:noinspection GoDeprecation
    77  			if got := rec.HeaderMap.Get(key); got != want {
    78  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
    79  			}
    80  			return nil
    81  		}
    82  	}
    83  	hasHeader := func(key, want string) checkFunc {
    84  		return func(rec *ResponseRecorder) error {
    85  			if got := rec.Result().Header.Get(key); got != want {
    86  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
    87  			}
    88  			return nil
    89  		}
    90  	}
    91  	hasNotHeaders := func(keys ...string) checkFunc {
    92  		return func(rec *ResponseRecorder) error {
    93  			for _, k := range keys {
    94  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
    95  				if ok {
    96  					return fmt.Errorf("unexpected header %s with value %q", k, v)
    97  				}
    98  			}
    99  			return nil
   100  		}
   101  	}
   102  	hasTrailer := func(key, want string) checkFunc {
   103  		return func(rec *ResponseRecorder) error {
   104  			if got := rec.Result().Trailer.Get(key); got != want {
   105  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
   106  			}
   107  			return nil
   108  		}
   109  	}
   110  	hasNotTrailers := func(keys ...string) checkFunc {
   111  		return func(rec *ResponseRecorder) error {
   112  			trailers := rec.Result().Trailer
   113  			for _, k := range keys {
   114  				_, ok := trailers[http.CanonicalHeaderKey(k)]
   115  				if ok {
   116  					return fmt.Errorf("unexpected trailer %s", k)
   117  				}
   118  			}
   119  			return nil
   120  		}
   121  	}
   122  	hasContentLength := func(length int64) checkFunc {
   123  		return func(rec *ResponseRecorder) error {
   124  			if got := rec.Result().ContentLength; got != length {
   125  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
   126  			}
   127  			return nil
   128  		}
   129  	}
   130  
   131  	for _, tt := range [...]struct {
   132  		name   string
   133  		h      func(w http.ResponseWriter, r *http.Request)
   134  		checks []checkFunc
   135  	}{
   136  		{
   137  			"200 default",
   138  			func(w http.ResponseWriter, r *http.Request) {},
   139  			check(hasStatus(200), hasContents("")),
   140  		},
   141  		{
   142  			"first code only",
   143  			func(w http.ResponseWriter, r *http.Request) {
   144  				w.WriteHeader(201)
   145  				w.WriteHeader(202)
   146  				_, _ = w.Write([]byte("hi"))
   147  			},
   148  			check(hasStatus(201), hasContents("hi")),
   149  		},
   150  		{
   151  			"write sends 200",
   152  			func(w http.ResponseWriter, r *http.Request) {
   153  				_, _ = w.Write([]byte("hi first"))
   154  				w.WriteHeader(201)
   155  				w.WriteHeader(202)
   156  			},
   157  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   158  		},
   159  		{
   160  			"write string",
   161  			func(w http.ResponseWriter, r *http.Request) {
   162  				_, _ = io.WriteString(w, "hi first")
   163  			},
   164  			check(
   165  				hasStatus(200),
   166  				hasContents("hi first"),
   167  				hasFlush(false),
   168  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   169  			),
   170  		},
   171  		{
   172  			"flush",
   173  			func(w http.ResponseWriter, r *http.Request) {
   174  				w.(http.Flusher).Flush() // also sends a 200
   175  				w.WriteHeader(201)
   176  			},
   177  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
   178  		},
   179  		{
   180  			"Content-Type detection",
   181  			func(w http.ResponseWriter, r *http.Request) {
   182  				_, _ = io.WriteString(w, "<html>")
   183  			},
   184  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   185  		},
   186  		{
   187  			"no Content-Type detection with Transfer-Encoding",
   188  			func(w http.ResponseWriter, r *http.Request) {
   189  				w.Header().Set("Transfer-Encoding", "some encoding")
   190  				_, _ = io.WriteString(w, "<html>")
   191  			},
   192  			check(hasHeader("Content-Type", "")), // no header
   193  		},
   194  		{
   195  			"no Content-Type detection if set explicitly",
   196  			func(w http.ResponseWriter, r *http.Request) {
   197  				w.Header().Set("Content-Type", "some/type")
   198  				_, _ = io.WriteString(w, "<html>")
   199  			},
   200  			check(hasHeader("Content-Type", "some/type")),
   201  		},
   202  		{
   203  			"Content-Type detection doesn't crash if HeaderMap is nil",
   204  			func(w http.ResponseWriter, r *http.Request) {
   205  				// Act as if the user wrote new(httptest.ResponseRecorder)
   206  				// rather than using NewRecorder (which initializes
   207  				// HeaderMap)
   208  				//goland:noinspection GoDeprecation
   209  				w.(*ResponseRecorder).HeaderMap = nil
   210  				_, _ = io.WriteString(w, "<html>")
   211  			},
   212  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   213  		},
   214  		{
   215  			"Header is not changed after write",
   216  			func(w http.ResponseWriter, r *http.Request) {
   217  				hdr := w.Header()
   218  				hdr.Set("Key", "correct")
   219  				w.WriteHeader(200)
   220  				hdr.Set("Key", "incorrect")
   221  			},
   222  			check(hasHeader("Key", "correct")),
   223  		},
   224  		{
   225  			"Trailer headers are correctly recorded",
   226  			func(w http.ResponseWriter, r *http.Request) {
   227  				w.Header().Set("Non-Trailer", "correct")
   228  				w.Header().Set("Trailer", "Trailer-A")
   229  				w.Header().Add("Trailer", "Trailer-B")
   230  				w.Header().Add("Trailer", "Trailer-C")
   231  				_, _ = io.WriteString(w, "<html>")
   232  				w.Header().Set("Non-Trailer", "incorrect")
   233  				w.Header().Set("Trailer-A", "valuea")
   234  				w.Header().Set("Trailer-C", "valuec")
   235  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   236  				w.Header().Set("Trailer:Trailer-D", "with prefix")
   237  			},
   238  			check(
   239  				hasStatus(200),
   240  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   241  				hasHeader("Non-Trailer", "correct"),
   242  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   243  				hasTrailer("Trailer-A", "valuea"),
   244  				hasTrailer("Trailer-C", "valuec"),
   245  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   246  				hasTrailer("Trailer-D", "with prefix"),
   247  			),
   248  		},
   249  		{
   250  			"Header set without any write", // Issue 15560
   251  			func(w http.ResponseWriter, r *http.Request) {
   252  				w.Header().Set("X-Foo", "1")
   253  
   254  				// Simulate somebody using
   255  				// new(ResponseRecorder) instead of
   256  				// using the constructor which sets
   257  				// this to 200
   258  				w.(*ResponseRecorder).Code = 0
   259  			},
   260  			check(
   261  				hasOldHeader("X-Foo", "1"),
   262  				hasStatus(0),
   263  				hasHeader("X-Foo", "1"),
   264  				hasResultStatus("200 OK"),
   265  				hasResultStatusCode(200),
   266  			),
   267  		},
   268  		{
   269  			"HeaderMap vs FinalHeaders", // more for Issue 15560
   270  			func(w http.ResponseWriter, r *http.Request) {
   271  				h := w.Header()
   272  				h.Set("X-Foo", "1")
   273  				_, _ = w.Write([]byte("hi"))
   274  				h.Set("X-Foo", "2")
   275  				h.Set("X-Bar", "2")
   276  			},
   277  			check(
   278  				hasOldHeader("X-Foo", "2"),
   279  				hasOldHeader("X-Bar", "2"),
   280  				hasHeader("X-Foo", "1"),
   281  				hasNotHeaders("X-Bar"),
   282  			),
   283  		},
   284  		{
   285  			"setting Content-Length header",
   286  			func(w http.ResponseWriter, r *http.Request) {
   287  				body := "Some body"
   288  				contentLength := fmt.Sprintf("%d", len(body))
   289  				w.Header().Set("Content-Length", contentLength)
   290  				_, _ = io.WriteString(w, body)
   291  			},
   292  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
   293  		},
   294  		{
   295  			"nil ResponseRecorder.Body", // Issue 26642
   296  			func(w http.ResponseWriter, r *http.Request) {
   297  				w.(*ResponseRecorder).Body = nil
   298  				_, _ = io.WriteString(w, "hi")
   299  			},
   300  			check(hasResultContents("")), // check we don't crash reading the body
   301  
   302  		},
   303  	} {
   304  		t.Run(tt.name, func(t *testing.T) {
   305  			//goland:noinspection HttpUrlsUsage
   306  			r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   307  			h := http.HandlerFunc(tt.h)
   308  			rec := NewRecorder()
   309  			h.ServeHTTP(rec, r)
   310  			for _, check := range tt.checks {
   311  				if err := check(rec); err != nil {
   312  					t.Error(err)
   313  				}
   314  			}
   315  		})
   316  	}
   317  }
   318  
   319  // issue 39017 - disallow Content-Length values such as "+3"
   320  func TestParseContentLength(t *testing.T) {
   321  	tests := []struct {
   322  		cl   string
   323  		want int64
   324  	}{
   325  		{
   326  			cl:   "3",
   327  			want: 3,
   328  		},
   329  		{
   330  			cl:   "+3",
   331  			want: -1,
   332  		},
   333  		{
   334  			cl:   "-3",
   335  			want: -1,
   336  		},
   337  		{
   338  			// max int64, for safe conversion before returning
   339  			cl:   "9223372036854775807",
   340  			want: 9223372036854775807,
   341  		},
   342  		{
   343  			cl:   "9223372036854775808",
   344  			want: -1,
   345  		},
   346  	}
   347  
   348  	for _, tt := range tests {
   349  		if got := parseContentLength(tt.cl); got != tt.want {
   350  			t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
   351  		}
   352  	}
   353  }
   354  
   355  // Ensure that httptest.Recorder panics when given a non-3 digit (XXX)
   356  // status HTTP code. See https://golang.org/issues/45353
   357  func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
   358  	badCodes := []int{
   359  		-100, 0, 99, 1000, 20000,
   360  	}
   361  	for _, badCode := range badCodes {
   362  		badCode := badCode
   363  		t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
   364  			defer func() {
   365  				if r := recover(); r == nil {
   366  					t.Fatal("Expected a panic")
   367  				}
   368  			}()
   369  
   370  			handler := func(rw http.ResponseWriter, _ *http.Request) {
   371  				rw.WriteHeader(badCode)
   372  			}
   373  			//goland:noinspection HttpUrlsUsage
   374  			r, _ := http.NewRequest("GET", "http://example.org/", nil)
   375  			rw := NewRecorder()
   376  			handler(rw, r)
   377  		})
   378  	}
   379  }