github.com/xmidt-org/webpa-common@v1.11.9/xhttp/xhttptest/transactor.go (about)

     1  package xhttptest
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/textproto"
    10  	"net/url"
    11  	"strings"
    12  
    13  	"github.com/stretchr/testify/mock"
    14  )
    15  
// ExpectedResponse is a tuple of the expected return values from a mocked transaction (Do or
// RoundTrip).  This struct provides a simple unit to build table-driven tests from.
//
// It is consumed by (*TransactCall).RespondWith: when Err is nil, the remaining fields describe
// the *http.Response to return; when Err is non-nil, a nil *http.Response is returned with Err.
type ExpectedResponse struct {
	// StatusCode is the status code of the built response (ignored when Err is non-nil).
	StatusCode int

	// Body is the body content of the built response (ignored when Err is non-nil).
	Body []byte

	// Header entries are set on the built response's Header (ignored when Err is non-nil).
	Header http.Header

	// Err, when non-nil, is returned as the error and suppresses building a response.
	Err error
}
    24  
// TransactCall is a stretchr mock Call with some extra behavior to make mocking out HTTP client
// behavior easier.
//
// It embeds *mock.Call, so the usual Call configuration methods remain available, and adds
// Respond and RespondWith for declaring the (*http.Response, error) result of the mocked call.
// MockTransactor's OnDo and OnRoundTrip return values of this type.
type TransactCall struct {
	*mock.Call
}
    29  
    30  // RespondWith creates an (*http.Response, error) tuple from an ExpectedResponse.  If the Err field is nil,
    31  // an *http.Response is created from the other fields.  If the Err field is not nil, a nil *http.Response is used.
    32  func (dc *TransactCall) RespondWith(er ExpectedResponse) *TransactCall {
    33  	var response *http.Response
    34  	if er.Err == nil {
    35  		response = NewResponse(er.StatusCode, er.Body)
    36  		for key, values := range er.Header {
    37  			response.Header[key] = values
    38  		}
    39  	}
    40  
    41  	return dc.Respond(response, er.Err)
    42  }
    43  
    44  // Respond is a convenience for setting a Return(response, err)
    45  func (dc *TransactCall) Respond(response *http.Response, err error) *TransactCall {
    46  	dc.Return(response, err)
    47  	return dc
    48  }
    49  
    50  // MockTransactor is a stretchr mock for the Do method of an HTTP client or round tripper.
    51  // This mock extends the behavior of a stretchr mock in a few ways that make clientside
    52  // HTTP behavior easier to mock.
    53  //
    54  // This type implements the http.RoundTripper interface, and provides a Do method that can
    55  // implement a subset interface of http.Client.
    56  type MockTransactor struct {
    57  	mock.Mock
    58  }
    59  
    60  // Do is a mocked HTTP transaction call.  Use On or OnRequest to setup behaviors for this method.
    61  func (mt *MockTransactor) Do(request *http.Request) (*http.Response, error) {
    62  	// HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object.
    63  	// Called performs a printf, which bypasses the context's mutex to produce the string.  We have to replace
    64  	// the context with a known, immutable value so that no race conditions occur.
    65  	arguments := mt.Called(request.WithContext(context.Background()))
    66  	response, _ := arguments.Get(0).(*http.Response)
    67  	return response, arguments.Error(1)
    68  }
    69  
    70  // RoundTrip is a mocked HTTP transaction call.  Use On or OnRoundTrip to setup behaviors for this method.
    71  func (mt *MockTransactor) RoundTrip(request *http.Request) (*http.Response, error) {
    72  	// HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object.
    73  	// Called performs a printf, which bypasses the context's mutex to produce the string.  We have to replace
    74  	// the context with a known, immutable value so that no race conditions occur.
    75  	arguments := mt.Called(request.WithContext(context.Background()))
    76  	response, _ := arguments.Get(0).(*http.Response)
    77  	return response, arguments.Error(1)
    78  }
    79  
    80  // OnDo sets an On("Do", ...) with the given matchers for a request.  The returned Call has some
    81  // augmented behavior for setting responses.
    82  func (mt *MockTransactor) OnDo(matchers ...func(*http.Request) bool) *TransactCall {
    83  	call := mt.On("Do", mock.MatchedBy(func(candidate *http.Request) bool {
    84  		for _, matcher := range matchers {
    85  			if !matcher(candidate) {
    86  				return false
    87  			}
    88  		}
    89  
    90  		return true
    91  	}))
    92  
    93  	return &TransactCall{call}
    94  }
    95  
    96  // OnRoundTrip sets an On("Do", ...) with the given matchers for a request.  The returned Call has some
    97  // augmented behavior for setting responses.
    98  func (mt *MockTransactor) OnRoundTrip(matchers ...func(*http.Request) bool) *TransactCall {
    99  	call := mt.On("RoundTrip", mock.MatchedBy(func(candidate *http.Request) bool {
   100  		for _, matcher := range matchers {
   101  			if !matcher(candidate) {
   102  				return false
   103  			}
   104  		}
   105  
   106  		return true
   107  	}))
   108  
   109  	return &TransactCall{call}
   110  }
   111  
   112  // MatchMethod returns a request matcher that verifies each request has a specific method
   113  func MatchMethod(expected string) func(*http.Request) bool {
   114  	return func(r *http.Request) bool {
   115  		return strings.EqualFold(expected, r.Method)
   116  	}
   117  }
   118  
   119  // MatchURL returns a request matcher that verifies each request has an exact URL.
   120  func MatchURL(expected *url.URL) func(*http.Request) bool {
   121  	return func(r *http.Request) bool {
   122  		if expected == r.URL {
   123  			return true
   124  		}
   125  
   126  		if expected == nil || r.URL == nil {
   127  			return false
   128  		}
   129  
   130  		return *expected == *r.URL
   131  	}
   132  }
   133  
   134  // MatchURLString returns a request matcher that verifies the request's URL translates to the given string.
   135  func MatchURLString(expected string) func(*http.Request) bool {
   136  	return func(r *http.Request) bool {
   137  		if r.URL == nil {
   138  			return len(expected) == 0
   139  		}
   140  
   141  		return expected == r.URL.String()
   142  	}
   143  }
   144  
   145  // MatchBody returns a request matcher that verifies each request has an exact body.
   146  // The body is consumed, but then replaced so that downstream code can still access the body.
   147  func MatchBody(expected []byte) func(*http.Request) bool {
   148  	return func(r *http.Request) bool {
   149  		if r.Body == nil {
   150  			return len(expected) == 0
   151  		}
   152  
   153  		actual, err := ioutil.ReadAll(r.Body)
   154  		if err != nil {
   155  			panic(fmt.Errorf("Error while read request body for matching: %s", err))
   156  		}
   157  
   158  		// replace the body so other test code can reread it
   159  		r.Body = ioutil.NopCloser(bytes.NewReader(actual))
   160  
   161  		if len(actual) != len(expected) {
   162  			return false
   163  		}
   164  
   165  		for i := 0; i < len(actual); i++ {
   166  			if actual[i] != expected[i] {
   167  				return false
   168  			}
   169  		}
   170  
   171  		return true
   172  	}
   173  }
   174  
   175  func MatchBodyString(expected string) func(*http.Request) bool {
   176  	return MatchBody([]byte(expected))
   177  }
   178  
   179  // MatchHeader returns a request matcher that matches against a request header
   180  func MatchHeader(name, expected string) func(*http.Request) bool {
   181  	return func(r *http.Request) bool {
   182  		// allow for requests created by test code that instantiates the request directly
   183  		if r.Header == nil {
   184  			return false
   185  		}
   186  
   187  		values := r.Header[textproto.CanonicalMIMEHeaderKey(name)]
   188  		if len(values) == 0 {
   189  			return len(expected) == 0
   190  		}
   191  
   192  		for _, actual := range values {
   193  			if actual == expected {
   194  				return true
   195  			}
   196  		}
   197  
   198  		return false
   199  	}
   200  }