github.com/slayercat/go@v0.0.0-20170428012452-c51559813f61/src/net/http/httptest/recorder.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  // ResponseRecorder is an implementation of http.ResponseWriter that
    16  // records its mutations for later inspection in tests.
    17  type ResponseRecorder struct {
    18  	// Code is the HTTP response code set by WriteHeader.
    19  	//
    20  	// Note that if a Handler never calls WriteHeader or Write,
    21  	// this might end up being 0, rather than the implicit
    22  	// http.StatusOK. To get the implicit value, use the Result
    23  	// method.
    24  	Code int
    25  
    26  	// HeaderMap contains the headers explicitly set by the Handler.
    27  	//
    28  	// To get the implicit headers set by the server (such as
    29  	// automatic Content-Type), use the Result method.
    30  	HeaderMap http.Header
    31  
    32  	// Body is the buffer to which the Handler's Write calls are sent.
    33  	// If nil, the Writes are silently discarded.
    34  	Body *bytes.Buffer
    35  
    36  	// Flushed is whether the Handler called Flush.
    37  	Flushed bool
    38  
    39  	result      *http.Response // cache of Result's return value
    40  	snapHeader  http.Header    // snapshot of HeaderMap at first Write
    41  	wroteHeader bool
    42  }
    43  
    44  // NewRecorder returns an initialized ResponseRecorder.
    45  func NewRecorder() *ResponseRecorder {
    46  	return &ResponseRecorder{
    47  		HeaderMap: make(http.Header),
    48  		Body:      new(bytes.Buffer),
    49  		Code:      200,
    50  	}
    51  }
    52  
    53  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    54  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    55  const DefaultRemoteAddr = "1.2.3.4"
    56  
    57  // Header returns the response headers.
    58  func (rw *ResponseRecorder) Header() http.Header {
    59  	m := rw.HeaderMap
    60  	if m == nil {
    61  		m = make(http.Header)
    62  		rw.HeaderMap = m
    63  	}
    64  	return m
    65  }
    66  
    67  // writeHeader writes a header if it was not written yet and
    68  // detects Content-Type if needed.
    69  //
    70  // bytes or str are the beginning of the response body.
    71  // We pass both to avoid unnecessarily generate garbage
    72  // in rw.WriteString which was created for performance reasons.
    73  // Non-nil bytes win.
    74  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    75  	if rw.wroteHeader {
    76  		return
    77  	}
    78  	if len(str) > 512 {
    79  		str = str[:512]
    80  	}
    81  
    82  	m := rw.Header()
    83  
    84  	_, hasType := m["Content-Type"]
    85  	hasTE := m.Get("Transfer-Encoding") != ""
    86  	if !hasType && !hasTE {
    87  		if b == nil {
    88  			b = []byte(str)
    89  		}
    90  		m.Set("Content-Type", http.DetectContentType(b))
    91  	}
    92  
    93  	rw.WriteHeader(200)
    94  }
    95  
    96  // Write always succeeds and writes to rw.Body, if not nil.
    97  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
    98  	rw.writeHeader(buf, "")
    99  	if rw.Body != nil {
   100  		rw.Body.Write(buf)
   101  	}
   102  	return len(buf), nil
   103  }
   104  
   105  // WriteString always succeeds and writes to rw.Body, if not nil.
   106  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   107  	rw.writeHeader(nil, str)
   108  	if rw.Body != nil {
   109  		rw.Body.WriteString(str)
   110  	}
   111  	return len(str), nil
   112  }
   113  
   114  // WriteHeader sets rw.Code. After it is called, changing rw.Header
   115  // will not affect rw.HeaderMap.
   116  func (rw *ResponseRecorder) WriteHeader(code int) {
   117  	if rw.wroteHeader {
   118  		return
   119  	}
   120  	rw.Code = code
   121  	rw.wroteHeader = true
   122  	if rw.HeaderMap == nil {
   123  		rw.HeaderMap = make(http.Header)
   124  	}
   125  	rw.snapHeader = cloneHeader(rw.HeaderMap)
   126  }
   127  
   128  func cloneHeader(h http.Header) http.Header {
   129  	h2 := make(http.Header, len(h))
   130  	for k, vv := range h {
   131  		vv2 := make([]string, len(vv))
   132  		copy(vv2, vv)
   133  		h2[k] = vv2
   134  	}
   135  	return h2
   136  }
   137  
   138  // Flush sets rw.Flushed to true.
   139  func (rw *ResponseRecorder) Flush() {
   140  	if !rw.wroteHeader {
   141  		rw.WriteHeader(200)
   142  	}
   143  	rw.Flushed = true
   144  }
   145  
   146  // Result returns the response generated by the handler.
   147  //
   148  // The returned Response will have at least its StatusCode,
   149  // Header, Body, and optionally Trailer populated.
   150  // More fields may be populated in the future, so callers should
   151  // not DeepEqual the result in tests.
   152  //
   153  // The Response.Header is a snapshot of the headers at the time of the
   154  // first write call, or at the time of this call, if the handler never
   155  // did a write.
   156  //
   157  // The Response.Body is guaranteed to be non-nil and Body.Read call is
   158  // guaranteed to not return any error other than io.EOF.
   159  //
   160  // Result must only be called after the handler has finished running.
   161  func (rw *ResponseRecorder) Result() *http.Response {
   162  	if rw.result != nil {
   163  		return rw.result
   164  	}
   165  	if rw.snapHeader == nil {
   166  		rw.snapHeader = cloneHeader(rw.HeaderMap)
   167  	}
   168  	res := &http.Response{
   169  		Proto:      "HTTP/1.1",
   170  		ProtoMajor: 1,
   171  		ProtoMinor: 1,
   172  		StatusCode: rw.Code,
   173  		Header:     rw.snapHeader,
   174  	}
   175  	rw.result = res
   176  	if res.StatusCode == 0 {
   177  		res.StatusCode = 200
   178  	}
   179  	res.Status = http.StatusText(res.StatusCode)
   180  	if rw.Body != nil {
   181  		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   182  	}
   183  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   184  
   185  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   186  		res.Trailer = make(http.Header, len(trailers))
   187  		for _, k := range trailers {
   188  			// TODO: use http2.ValidTrailerHeader, but we can't
   189  			// get at it easily because it's bundled into net/http
   190  			// unexported. This is good enough for now:
   191  			switch k {
   192  			case "Transfer-Encoding", "Content-Length", "Trailer":
   193  				// Ignore since forbidden by RFC 2616 14.40.
   194  				continue
   195  			}
   196  			k = http.CanonicalHeaderKey(k)
   197  			vv, ok := rw.HeaderMap[k]
   198  			if !ok {
   199  				continue
   200  			}
   201  			vv2 := make([]string, len(vv))
   202  			copy(vv2, vv)
   203  			res.Trailer[k] = vv2
   204  		}
   205  	}
   206  	for k, vv := range rw.HeaderMap {
   207  		if !strings.HasPrefix(k, http.TrailerPrefix) {
   208  			continue
   209  		}
   210  		if res.Trailer == nil {
   211  			res.Trailer = make(http.Header)
   212  		}
   213  		for _, v := range vv {
   214  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   215  		}
   216  	}
   217  	return res
   218  }
   219  
   220  // parseContentLength trims whitespace from s and returns -1 if no value
   221  // is set, or the value if it's >= 0.
   222  //
   223  // This a modified version of same function found in net/http/transfer.go. This
   224  // one just ignores an invalid header.
   225  func parseContentLength(cl string) int64 {
   226  	cl = strings.TrimSpace(cl)
   227  	if cl == "" {
   228  		return -1
   229  	}
   230  	n, err := strconv.ParseInt(cl, 10, 64)
   231  	if err != nil {
   232  		return -1
   233  	}
   234  	return n
   235  }