github.com/yukk001/go1.10.8@v0.0.0-20190813125351-6df2d3982e20/src/net/http/httptest/recorder.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"strconv"
    13  	"strings"
    14  )
    15  
// ResponseRecorder is an implementation of http.ResponseWriter (and, via its
// Flush method, http.Flusher) that records its mutations for later
// inspection in tests.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// Note that if a Handler never calls WriteHeader or Write on a
	// zero-value ResponseRecorder, this might end up being 0, rather
	// than the implicit http.StatusOK (NewRecorder presets it to 200).
	// To get the implicit value, use the Result method.
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is the same map returned by Header, so it keeps receiving
	// changes made after the first write; Result instead reports a
	// snapshot taken when the status was committed.
	//
	// To get the implicit headers set by the server (such as
	// automatic Content-Type), use the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value
	snapHeader  http.Header    // copy of HeaderMap taken by WriteHeader (or by Result if no status was ever written)
	wroteHeader bool           // set once WriteHeader has run; later status writes are ignored
}
    44  
    45  // NewRecorder returns an initialized ResponseRecorder.
    46  func NewRecorder() *ResponseRecorder {
    47  	return &ResponseRecorder{
    48  		HeaderMap: make(http.Header),
    49  		Body:      new(bytes.Buffer),
    50  		Code:      200,
    51  	}
    52  }
    53  
    54  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    55  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    56  const DefaultRemoteAddr = "1.2.3.4"
    57  
    58  // Header returns the response headers.
    59  func (rw *ResponseRecorder) Header() http.Header {
    60  	m := rw.HeaderMap
    61  	if m == nil {
    62  		m = make(http.Header)
    63  		rw.HeaderMap = m
    64  	}
    65  	return m
    66  }
    67  
    68  // writeHeader writes a header if it was not written yet and
    69  // detects Content-Type if needed.
    70  //
    71  // bytes or str are the beginning of the response body.
    72  // We pass both to avoid unnecessarily generate garbage
    73  // in rw.WriteString which was created for performance reasons.
    74  // Non-nil bytes win.
    75  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    76  	if rw.wroteHeader {
    77  		return
    78  	}
    79  	if len(str) > 512 {
    80  		str = str[:512]
    81  	}
    82  
    83  	m := rw.Header()
    84  
    85  	_, hasType := m["Content-Type"]
    86  	hasTE := m.Get("Transfer-Encoding") != ""
    87  	if !hasType && !hasTE {
    88  		if b == nil {
    89  			b = []byte(str)
    90  		}
    91  		m.Set("Content-Type", http.DetectContentType(b))
    92  	}
    93  
    94  	rw.WriteHeader(200)
    95  }
    96  
    97  // Write always succeeds and writes to rw.Body, if not nil.
    98  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
    99  	rw.writeHeader(buf, "")
   100  	if rw.Body != nil {
   101  		rw.Body.Write(buf)
   102  	}
   103  	return len(buf), nil
   104  }
   105  
   106  // WriteString always succeeds and writes to rw.Body, if not nil.
   107  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   108  	rw.writeHeader(nil, str)
   109  	if rw.Body != nil {
   110  		rw.Body.WriteString(str)
   111  	}
   112  	return len(str), nil
   113  }
   114  
   115  // WriteHeader sets rw.Code. After it is called, changing rw.Header
   116  // will not affect rw.HeaderMap.
   117  func (rw *ResponseRecorder) WriteHeader(code int) {
   118  	if rw.wroteHeader {
   119  		return
   120  	}
   121  	rw.Code = code
   122  	rw.wroteHeader = true
   123  	if rw.HeaderMap == nil {
   124  		rw.HeaderMap = make(http.Header)
   125  	}
   126  	rw.snapHeader = cloneHeader(rw.HeaderMap)
   127  }
   128  
   129  func cloneHeader(h http.Header) http.Header {
   130  	h2 := make(http.Header, len(h))
   131  	for k, vv := range h {
   132  		vv2 := make([]string, len(vv))
   133  		copy(vv2, vv)
   134  		h2[k] = vv2
   135  	}
   136  	return h2
   137  }
   138  
   139  // Flush sets rw.Flushed to true.
   140  func (rw *ResponseRecorder) Flush() {
   141  	if !rw.wroteHeader {
   142  		rw.WriteHeader(200)
   143  	}
   144  	rw.Flushed = true
   145  }
   146  
   147  // Result returns the response generated by the handler.
   148  //
   149  // The returned Response will have at least its StatusCode,
   150  // Header, Body, and optionally Trailer populated.
   151  // More fields may be populated in the future, so callers should
   152  // not DeepEqual the result in tests.
   153  //
   154  // The Response.Header is a snapshot of the headers at the time of the
   155  // first write call, or at the time of this call, if the handler never
   156  // did a write.
   157  //
   158  // The Response.Body is guaranteed to be non-nil and Body.Read call is
   159  // guaranteed to not return any error other than io.EOF.
   160  //
   161  // Result must only be called after the handler has finished running.
   162  func (rw *ResponseRecorder) Result() *http.Response {
   163  	if rw.result != nil {
   164  		return rw.result
   165  	}
   166  	if rw.snapHeader == nil {
   167  		rw.snapHeader = cloneHeader(rw.HeaderMap)
   168  	}
   169  	res := &http.Response{
   170  		Proto:      "HTTP/1.1",
   171  		ProtoMajor: 1,
   172  		ProtoMinor: 1,
   173  		StatusCode: rw.Code,
   174  		Header:     rw.snapHeader,
   175  	}
   176  	rw.result = res
   177  	if res.StatusCode == 0 {
   178  		res.StatusCode = 200
   179  	}
   180  	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
   181  	if rw.Body != nil {
   182  		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   183  	}
   184  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   185  
   186  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   187  		res.Trailer = make(http.Header, len(trailers))
   188  		for _, k := range trailers {
   189  			// TODO: use http2.ValidTrailerHeader, but we can't
   190  			// get at it easily because it's bundled into net/http
   191  			// unexported. This is good enough for now:
   192  			switch k {
   193  			case "Transfer-Encoding", "Content-Length", "Trailer":
   194  				// Ignore since forbidden by RFC 2616 14.40.
   195  				continue
   196  			}
   197  			k = http.CanonicalHeaderKey(k)
   198  			vv, ok := rw.HeaderMap[k]
   199  			if !ok {
   200  				continue
   201  			}
   202  			vv2 := make([]string, len(vv))
   203  			copy(vv2, vv)
   204  			res.Trailer[k] = vv2
   205  		}
   206  	}
   207  	for k, vv := range rw.HeaderMap {
   208  		if !strings.HasPrefix(k, http.TrailerPrefix) {
   209  			continue
   210  		}
   211  		if res.Trailer == nil {
   212  			res.Trailer = make(http.Header)
   213  		}
   214  		for _, v := range vv {
   215  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   216  		}
   217  	}
   218  	return res
   219  }
   220  
   221  // parseContentLength trims whitespace from s and returns -1 if no value
   222  // is set, or the value if it's >= 0.
   223  //
   224  // This a modified version of same function found in net/http/transfer.go. This
   225  // one just ignores an invalid header.
   226  func parseContentLength(cl string) int64 {
   227  	cl = strings.TrimSpace(cl)
   228  	if cl == "" {
   229  		return -1
   230  	}
   231  	n, err := strconv.ParseInt(cl, 10, 64)
   232  	if err != nil {
   233  		return -1
   234  	}
   235  	return n
   236  }