github.com/4ad/go@v0.0.0-20161219182952-69a12818b605/src/net/http/httptest/recorder.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"io/ioutil"
    10  	"net/http"
    11  )
    12  
    13  // ResponseRecorder is an implementation of http.ResponseWriter that
    14  // records its mutations for later inspection in tests.
    15  type ResponseRecorder struct {
    16  	Code      int           // the HTTP response code from WriteHeader
    17  	HeaderMap http.Header   // the HTTP response headers
    18  	Body      *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
    19  	Flushed   bool
    20  
    21  	result      *http.Response // cache of Result's return value
    22  	snapHeader  http.Header    // snapshot of HeaderMap at first Write
    23  	wroteHeader bool
    24  }
    25  
    26  // NewRecorder returns an initialized ResponseRecorder.
    27  func NewRecorder() *ResponseRecorder {
    28  	return &ResponseRecorder{
    29  		HeaderMap: make(http.Header),
    30  		Body:      new(bytes.Buffer),
    31  		Code:      200,
    32  	}
    33  }
    34  
    35  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    36  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    37  const DefaultRemoteAddr = "1.2.3.4"
    38  
    39  // Header returns the response headers.
    40  func (rw *ResponseRecorder) Header() http.Header {
    41  	m := rw.HeaderMap
    42  	if m == nil {
    43  		m = make(http.Header)
    44  		rw.HeaderMap = m
    45  	}
    46  	return m
    47  }
    48  
    49  // writeHeader writes a header if it was not written yet and
    50  // detects Content-Type if needed.
    51  //
    52  // bytes or str are the beginning of the response body.
    53  // We pass both to avoid unnecessarily generate garbage
    54  // in rw.WriteString which was created for performance reasons.
    55  // Non-nil bytes win.
    56  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    57  	if rw.wroteHeader {
    58  		return
    59  	}
    60  	if len(str) > 512 {
    61  		str = str[:512]
    62  	}
    63  
    64  	m := rw.Header()
    65  
    66  	_, hasType := m["Content-Type"]
    67  	hasTE := m.Get("Transfer-Encoding") != ""
    68  	if !hasType && !hasTE {
    69  		if b == nil {
    70  			b = []byte(str)
    71  		}
    72  		m.Set("Content-Type", http.DetectContentType(b))
    73  	}
    74  
    75  	rw.WriteHeader(200)
    76  }
    77  
    78  // Write always succeeds and writes to rw.Body, if not nil.
    79  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
    80  	rw.writeHeader(buf, "")
    81  	if rw.Body != nil {
    82  		rw.Body.Write(buf)
    83  	}
    84  	return len(buf), nil
    85  }
    86  
    87  // WriteString always succeeds and writes to rw.Body, if not nil.
    88  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
    89  	rw.writeHeader(nil, str)
    90  	if rw.Body != nil {
    91  		rw.Body.WriteString(str)
    92  	}
    93  	return len(str), nil
    94  }
    95  
    96  // WriteHeader sets rw.Code. After it is called, changing rw.Header
    97  // will not affect rw.HeaderMap.
    98  func (rw *ResponseRecorder) WriteHeader(code int) {
    99  	if rw.wroteHeader {
   100  		return
   101  	}
   102  	rw.Code = code
   103  	rw.wroteHeader = true
   104  	if rw.HeaderMap == nil {
   105  		rw.HeaderMap = make(http.Header)
   106  	}
   107  	rw.snapHeader = cloneHeader(rw.HeaderMap)
   108  }
   109  
   110  func cloneHeader(h http.Header) http.Header {
   111  	h2 := make(http.Header, len(h))
   112  	for k, vv := range h {
   113  		vv2 := make([]string, len(vv))
   114  		copy(vv2, vv)
   115  		h2[k] = vv2
   116  	}
   117  	return h2
   118  }
   119  
   120  // Flush sets rw.Flushed to true.
   121  func (rw *ResponseRecorder) Flush() {
   122  	if !rw.wroteHeader {
   123  		rw.WriteHeader(200)
   124  	}
   125  	rw.Flushed = true
   126  }
   127  
   128  // Result returns the response generated by the handler.
   129  //
   130  // The returned Response will have at least its StatusCode,
   131  // Header, Body, and optionally Trailer populated.
   132  // More fields may be populated in the future, so callers should
   133  // not DeepEqual the result in tests.
   134  //
   135  // The Response.Header is a snapshot of the headers at the time of the
   136  // first write call, or at the time of this call, if the handler never
   137  // did a write.
   138  //
   139  // Result must only be called after the handler has finished running.
   140  func (rw *ResponseRecorder) Result() *http.Response {
   141  	if rw.result != nil {
   142  		return rw.result
   143  	}
   144  	if rw.snapHeader == nil {
   145  		rw.snapHeader = cloneHeader(rw.HeaderMap)
   146  	}
   147  	res := &http.Response{
   148  		Proto:      "HTTP/1.1",
   149  		ProtoMajor: 1,
   150  		ProtoMinor: 1,
   151  		StatusCode: rw.Code,
   152  		Header:     rw.snapHeader,
   153  	}
   154  	rw.result = res
   155  	if res.StatusCode == 0 {
   156  		res.StatusCode = 200
   157  	}
   158  	res.Status = http.StatusText(res.StatusCode)
   159  	if rw.Body != nil {
   160  		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   161  	}
   162  
   163  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   164  		res.Trailer = make(http.Header, len(trailers))
   165  		for _, k := range trailers {
   166  			// TODO: use http2.ValidTrailerHeader, but we can't
   167  			// get at it easily because it's bundled into net/http
   168  			// unexported. This is good enough for now:
   169  			switch k {
   170  			case "Transfer-Encoding", "Content-Length", "Trailer":
   171  				// Ignore since forbidden by RFC 2616 14.40.
   172  				continue
   173  			}
   174  			k = http.CanonicalHeaderKey(k)
   175  			vv, ok := rw.HeaderMap[k]
   176  			if !ok {
   177  				continue
   178  			}
   179  			vv2 := make([]string, len(vv))
   180  			copy(vv2, vv)
   181  			res.Trailer[k] = vv2
   182  		}
   183  	}
   184  	return res
   185  }