github.com/useflyent/fhttp@v0.0.0-20211004035111-333f430cfbbf/httptest/recorder.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"net/textproto"
    12  	"strconv"
    13  	"strings"
    14  
    15  	http "github.com/useflyent/fhttp"
    16  
    17  	"golang.org/x/net/http/httpguts"
    18  )
    19  
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
//
// NewRecorder returns a recorder with HeaderMap and Body allocated. The
// zero value is also usable: the header map is allocated on demand, but
// with a nil Body everything the handler writes is discarded.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// Note that if a Handler never calls WriteHeader or Write,
	// this might end up being 0, rather than the implicit
	// http.StatusOK. To get the implicit value, use the Result
	// method.
	//
	// NewRecorder initializes Code to 200. Write, WriteString and Flush
	// call WriteHeader(200) implicitly if no status was written yet.
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Keys beginning with http.TrailerPrefix are reported as trailers
	// (with the prefix removed) by Result.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value; built once
	snapHeader  http.Header    // clone of HeaderMap taken by the first WriteHeader (or by Result if nothing was written)
	wroteHeader bool           // whether WriteHeader has run; later calls are no-ops
}
    50  
    51  // NewRecorder returns an initialized ResponseRecorder.
    52  func NewRecorder() *ResponseRecorder {
    53  	return &ResponseRecorder{
    54  		HeaderMap: make(http.Header),
    55  		Body:      new(bytes.Buffer),
    56  		Code:      200,
    57  	}
    58  }
    59  
    60  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    61  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    62  const DefaultRemoteAddr = "1.2.3.4"
    63  
    64  // Header implements http.ResponseWriter. It returns the response
    65  // headers to mutate within a handler. To test the headers that were
    66  // written after a handler completes, use the Result method and see
    67  // the returned Response value's Header.
    68  func (rw *ResponseRecorder) Header() http.Header {
    69  	m := rw.HeaderMap
    70  	if m == nil {
    71  		m = make(http.Header)
    72  		rw.HeaderMap = m
    73  	}
    74  	return m
    75  }
    76  
    77  // writeHeader writes a header if it was not written yet and
    78  // detects Content-Type if needed.
    79  //
    80  // bytes or str are the beginning of the response body.
    81  // We pass both to avoid unnecessarily generate garbage
    82  // in rw.WriteString which was created for performance reasons.
    83  // Non-nil bytes win.
    84  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    85  	if rw.wroteHeader {
    86  		return
    87  	}
    88  	if len(str) > 512 {
    89  		str = str[:512]
    90  	}
    91  
    92  	m := rw.Header()
    93  
    94  	_, hasType := m["Content-Type"]
    95  	hasTE := m.Get("Transfer-Encoding") != ""
    96  	if !hasType && !hasTE {
    97  		if b == nil {
    98  			b = []byte(str)
    99  		}
   100  		m.Set("Content-Type", http.DetectContentType(b))
   101  	}
   102  
   103  	rw.WriteHeader(200)
   104  }
   105  
   106  // Write implements http.ResponseWriter. The data in buf is written to
   107  // rw.Body, if not nil.
   108  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
   109  	rw.writeHeader(buf, "")
   110  	if rw.Body != nil {
   111  		rw.Body.Write(buf)
   112  	}
   113  	return len(buf), nil
   114  }
   115  
   116  // WriteString implements io.StringWriter. The data in str is written
   117  // to rw.Body, if not nil.
   118  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   119  	rw.writeHeader(nil, str)
   120  	if rw.Body != nil {
   121  		rw.Body.WriteString(str)
   122  	}
   123  	return len(str), nil
   124  }
   125  
   126  // WriteHeader implements http.ResponseWriter.
   127  func (rw *ResponseRecorder) WriteHeader(code int) {
   128  	if rw.wroteHeader {
   129  		return
   130  	}
   131  	rw.Code = code
   132  	rw.wroteHeader = true
   133  	if rw.HeaderMap == nil {
   134  		rw.HeaderMap = make(http.Header)
   135  	}
   136  	rw.snapHeader = rw.HeaderMap.Clone()
   137  }
   138  
   139  // Flush implements http.Flusher. To test whether Flush was
   140  // called, see rw.Flushed.
   141  func (rw *ResponseRecorder) Flush() {
   142  	if !rw.wroteHeader {
   143  		rw.WriteHeader(200)
   144  	}
   145  	rw.Flushed = true
   146  }
   147  
   148  // Result returns the response generated by the handler.
   149  //
   150  // The returned Response will have at least its StatusCode,
   151  // Header, Body, and optionally Trailer populated.
   152  // More fields may be populated in the future, so callers should
   153  // not DeepEqual the result in tests.
   154  //
   155  // The Response.Header is a snapshot of the headers at the time of the
   156  // first write call, or at the time of this call, if the handler never
   157  // did a write.
   158  //
   159  // The Response.Body is guaranteed to be non-nil and Body.Read call is
   160  // guaranteed to not return any error other than io.EOF.
   161  //
   162  // Result must only be called after the handler has finished running.
   163  func (rw *ResponseRecorder) Result() *http.Response {
   164  	if rw.result != nil {
   165  		return rw.result
   166  	}
   167  	if rw.snapHeader == nil {
   168  		rw.snapHeader = rw.HeaderMap.Clone()
   169  	}
   170  	res := &http.Response{
   171  		Proto:      "HTTP/1.1",
   172  		ProtoMajor: 1,
   173  		ProtoMinor: 1,
   174  		StatusCode: rw.Code,
   175  		Header:     rw.snapHeader,
   176  	}
   177  	rw.result = res
   178  	if res.StatusCode == 0 {
   179  		res.StatusCode = 200
   180  	}
   181  	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
   182  	if rw.Body != nil {
   183  		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   184  	} else {
   185  		res.Body = http.NoBody
   186  	}
   187  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   188  
   189  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   190  		res.Trailer = make(http.Header, len(trailers))
   191  		for _, k := range trailers {
   192  			k = http.CanonicalHeaderKey(k)
   193  			if !httpguts.ValidTrailerHeader(k) {
   194  				// Ignore since forbidden by RFC 7230, section 4.1.2.
   195  				continue
   196  			}
   197  			vv, ok := rw.HeaderMap[k]
   198  			if !ok {
   199  				continue
   200  			}
   201  			vv2 := make([]string, len(vv))
   202  			copy(vv2, vv)
   203  			res.Trailer[k] = vv2
   204  		}
   205  	}
   206  	for k, vv := range rw.HeaderMap {
   207  		if !strings.HasPrefix(k, http.TrailerPrefix) {
   208  			continue
   209  		}
   210  		if res.Trailer == nil {
   211  			res.Trailer = make(http.Header)
   212  		}
   213  		for _, v := range vv {
   214  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   215  		}
   216  	}
   217  	return res
   218  }
   219  
   220  // parseContentLength trims whitespace from s and returns -1 if no value
   221  // is set, or the value if it's >= 0.
   222  //
   223  // This a modified version of same function found in net/http/transfer.go. This
   224  // one just ignores an invalid header.
   225  func parseContentLength(cl string) int64 {
   226  	cl = textproto.TrimString(cl)
   227  	if cl == "" {
   228  		return -1
   229  	}
   230  	n, err := strconv.ParseUint(cl, 10, 63)
   231  	if err != nil {
   232  		return -1
   233  	}
   234  	return int64(n)
   235  }