gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmhttp/httptest/recorder.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"net/textproto"
    12  	"strconv"
    13  	"strings"
    14  
    15  	http "gitee.com/ks-custle/core-gm/gmhttp"
    16  
    17  	"golang.org/x/net/http/httpguts"
    18  )
    19  
    20  // ResponseRecorder is an implementation of http.ResponseWriter that
    21  // records its mutations for later inspection in tests.
    22  type ResponseRecorder struct {
    23  	// Code is the HTTP response code set by WriteHeader.
    24  	//
    25  	// Note that if a Handler never calls WriteHeader or Write,
    26  	// this might end up being 0, rather than the implicit
    27  	// http.StatusOK. To get the implicit value, use the Result
    28  	// method.
    29  	Code int
    30  
    31  	// HeaderMap contains the headers explicitly set by the Handler.
    32  	// It is an internal detail.
    33  	//
    34  	// ToDeprecated: HeaderMap exists for historical compatibility
    35  	// and should not be used. To access the headers returned by a handler,
    36  	// use the Response.Header map as returned by the Result method.
    37  	HeaderMap http.Header
    38  
    39  	// Body is the buffer to which the Handler's Write calls are sent.
    40  	// If nil, the Writes are silently discarded.
    41  	Body *bytes.Buffer
    42  
    43  	// Flushed is whether the Handler called Flush.
    44  	Flushed bool
    45  
    46  	result      *http.Response // cache of Result's return value
    47  	snapHeader  http.Header    // snapshot of HeaderMap at first Write
    48  	wroteHeader bool
    49  }
    50  
    51  // NewRecorder returns an initialized ResponseRecorder.
    52  func NewRecorder() *ResponseRecorder {
    53  	//goland:noinspection GoDeprecation
    54  	return &ResponseRecorder{
    55  		HeaderMap: make(http.Header),
    56  		Body:      new(bytes.Buffer),
    57  		Code:      200,
    58  	}
    59  }
    60  
    61  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    62  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    63  //
    64  //goland:noinspection GoUnusedConst
    65  const DefaultRemoteAddr = "1.2.3.4"
    66  
    67  // Header implements http.ResponseWriter. It returns the response
    68  // headers to mutate within a handler. To test the headers that were
    69  // written after a handler completes, use the Result method and see
    70  // the returned Response value's Header.
    71  func (rw *ResponseRecorder) Header() http.Header {
    72  	//goland:noinspection GoDeprecation
    73  	m := rw.HeaderMap
    74  	if m == nil {
    75  		m = make(http.Header)
    76  		//goland:noinspection GoDeprecation
    77  		rw.HeaderMap = m
    78  	}
    79  	return m
    80  }
    81  
    82  // writeHeader writes a header if it was not written yet and
    83  // detects Content-Type if needed.
    84  //
    85  // bytes or str are the beginning of the response body.
    86  // We pass both to avoid unnecessarily generate garbage
    87  // in rw.WriteString which was created for performance reasons.
    88  // Non-nil bytes win.
    89  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    90  	if rw.wroteHeader {
    91  		return
    92  	}
    93  	if len(str) > 512 {
    94  		str = str[:512]
    95  	}
    96  
    97  	m := rw.Header()
    98  
    99  	_, hasType := m["Content-Type"]
   100  	hasTE := m.Get("Transfer-Encoding") != ""
   101  	if !hasType && !hasTE {
   102  		if b == nil {
   103  			b = []byte(str)
   104  		}
   105  		m.Set("Content-Type", http.DetectContentType(b))
   106  	}
   107  
   108  	rw.WriteHeader(200)
   109  }
   110  
   111  // Write implements http.ResponseWriter. The data in buf is written to
   112  // rw.Body, if not nil.
   113  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
   114  	rw.writeHeader(buf, "")
   115  	if rw.Body != nil {
   116  		rw.Body.Write(buf)
   117  	}
   118  	return len(buf), nil
   119  }
   120  
   121  // WriteString implements io.StringWriter. The data in str is written
   122  // to rw.Body, if not nil.
   123  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   124  	rw.writeHeader(nil, str)
   125  	if rw.Body != nil {
   126  		rw.Body.WriteString(str)
   127  	}
   128  	return len(str), nil
   129  }
   130  
   131  func checkWriteHeaderCode(code int) {
   132  	// Issue 22880: require valid WriteHeader status codes.
   133  	// For now we only enforce that it's three digits.
   134  	// In the future we might block things over 599 (600 and above aren't defined
   135  	// at https://httpwg.org/specs/rfc7231.html#status.codes)
   136  	// and we might block under 200 (once we have more mature 1xx support).
   137  	// But for now any three digits.
   138  	//
   139  	// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
   140  	// no equivalent bogus thing we can realistically send in HTTP/2,
   141  	// so we'll consistently panic instead and help people find their bugs
   142  	// early. (We can't return an error from WriteHeader even if we wanted to.)
   143  	if code < 100 || code > 999 {
   144  		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
   145  	}
   146  }
   147  
   148  // WriteHeader implements http.ResponseWriter.
   149  //
   150  //goland:noinspection GoDeprecation
   151  func (rw *ResponseRecorder) WriteHeader(code int) {
   152  	if rw.wroteHeader {
   153  		return
   154  	}
   155  
   156  	checkWriteHeaderCode(code)
   157  	rw.Code = code
   158  	rw.wroteHeader = true
   159  	if rw.HeaderMap == nil {
   160  		rw.HeaderMap = make(http.Header)
   161  	}
   162  	rw.snapHeader = rw.HeaderMap.Clone()
   163  }
   164  
   165  // Flush implements http.Flusher. To test whether Flush was
   166  // called, see rw.Flushed.
   167  func (rw *ResponseRecorder) Flush() {
   168  	if !rw.wroteHeader {
   169  		rw.WriteHeader(200)
   170  	}
   171  	rw.Flushed = true
   172  }
   173  
   174  // Result returns the response generated by the handler.
   175  //
   176  // The returned Response will have at least its StatusCode,
   177  // Header, Body, and optionally Trailer populated.
   178  // More fields may be populated in the future, so callers should
   179  // not DeepEqual the result in tests.
   180  //
   181  // The Response.Header is a snapshot of the headers at the time of the
   182  // first write call, or at the time of this call, if the handler never
   183  // did a write.
   184  //
   185  // The Response.Body is guaranteed to be non-nil and Body.Read call is
   186  // guaranteed to not return any error other than io.EOF.
   187  //
   188  // Result must only be called after the handler has finished running.
   189  //
   190  //goland:noinspection GoDeprecation
   191  func (rw *ResponseRecorder) Result() *http.Response {
   192  	if rw.result != nil {
   193  		return rw.result
   194  	}
   195  	if rw.snapHeader == nil {
   196  		rw.snapHeader = rw.HeaderMap.Clone()
   197  	}
   198  	res := &http.Response{
   199  		Proto:      "HTTP/1.1",
   200  		ProtoMajor: 1,
   201  		ProtoMinor: 1,
   202  		StatusCode: rw.Code,
   203  		Header:     rw.snapHeader,
   204  	}
   205  	rw.result = res
   206  	if res.StatusCode == 0 {
   207  		res.StatusCode = 200
   208  	}
   209  	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
   210  	if rw.Body != nil {
   211  		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   212  	} else {
   213  		res.Body = http.NoBody
   214  	}
   215  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   216  
   217  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   218  		res.Trailer = make(http.Header, len(trailers))
   219  		for _, k := range trailers {
   220  			k = http.CanonicalHeaderKey(k)
   221  			if !httpguts.ValidTrailerHeader(k) {
   222  				// Ignore since forbidden by RFC 7230, section 4.1.2.
   223  				continue
   224  			}
   225  			vv, ok := rw.HeaderMap[k]
   226  			if !ok {
   227  				continue
   228  			}
   229  			vv2 := make([]string, len(vv))
   230  			copy(vv2, vv)
   231  			res.Trailer[k] = vv2
   232  		}
   233  	}
   234  	for k, vv := range rw.HeaderMap {
   235  		if !strings.HasPrefix(k, http.TrailerPrefix) {
   236  			continue
   237  		}
   238  		if res.Trailer == nil {
   239  			res.Trailer = make(http.Header)
   240  		}
   241  		for _, v := range vv {
   242  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   243  		}
   244  	}
   245  	return res
   246  }
   247  
   248  // parseContentLength trims whitespace from s and returns -1 if no value
   249  // is set, or the value if it's >= 0.
   250  //
   251  // This a modified version of same function found in net/http/transfer.go. This
   252  // one just ignores an invalid header.
   253  func parseContentLength(cl string) int64 {
   254  	cl = textproto.TrimString(cl)
   255  	if cl == "" {
   256  		return -1
   257  	}
   258  	n, err := strconv.ParseUint(cl, 10, 63)
   259  	if err != nil {
   260  		return -1
   261  	}
   262  	return int64(n)
   263  }