github.com/xmidt-org/webpa-common@v1.11.9/xhttp/xhttptest/transactor.go (about) 1 package xhttptest 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "net/textproto" 10 "net/url" 11 "strings" 12 13 "github.com/stretchr/testify/mock" 14 ) 15 16 // ExpectedResponse is a tuple of the expected return values from transactor.Do. This struct provides 17 // a simple unit to build table-driven tests from. 18 type ExpectedResponse struct { 19 StatusCode int 20 Body []byte 21 Header http.Header 22 Err error 23 } 24 25 // TransactCall is a stretchr mock Call with some extra behavior to make mocking out HTTP client behavior easier 26 type TransactCall struct { 27 *mock.Call 28 } 29 30 // RespondWith creates an (*http.Response, error) tuple from an ExpectedResponse. If the Err field is nil, 31 // an *http.Response is created from the other fields. If the Err field is not nil, a nil *http.Response is used. 32 func (dc *TransactCall) RespondWith(er ExpectedResponse) *TransactCall { 33 var response *http.Response 34 if er.Err == nil { 35 response = NewResponse(er.StatusCode, er.Body) 36 for key, values := range er.Header { 37 response.Header[key] = values 38 } 39 } 40 41 return dc.Respond(response, er.Err) 42 } 43 44 // Respond is a convenience for setting a Return(response, err) 45 func (dc *TransactCall) Respond(response *http.Response, err error) *TransactCall { 46 dc.Return(response, err) 47 return dc 48 } 49 50 // MockTransactor is a stretchr mock for the Do method of an HTTP client or round tripper. 51 // This mock extends the behavior of a stretchr mock in a few ways that make clientside 52 // HTTP behavior easier to mock. 53 // 54 // This type implements the http.RoundTripper interface, and provides a Do method that can 55 // implement a subset interface of http.Client. 56 type MockTransactor struct { 57 mock.Mock 58 } 59 60 // Do is a mocked HTTP transaction call. Use On or OnRequest to setup behaviors for this method. 
61 func (mt *MockTransactor) Do(request *http.Request) (*http.Response, error) { 62 // HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object. 63 // Called performs a printf, which bypasses the context's mutex to produce the string. We have to replace 64 // the context with a known, immutable value so that no race conditions occur. 65 arguments := mt.Called(request.WithContext(context.Background())) 66 response, _ := arguments.Get(0).(*http.Response) 67 return response, arguments.Error(1) 68 } 69 70 // RoundTrip is a mocked HTTP transaction call. Use On or OnRoundTrip to setup behaviors for this method. 71 func (mt *MockTransactor) RoundTrip(request *http.Request) (*http.Response, error) { 72 // HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object. 73 // Called performs a printf, which bypasses the context's mutex to produce the string. We have to replace 74 // the context with a known, immutable value so that no race conditions occur. 75 arguments := mt.Called(request.WithContext(context.Background())) 76 response, _ := arguments.Get(0).(*http.Response) 77 return response, arguments.Error(1) 78 } 79 80 // OnDo sets an On("Do", ...) with the given matchers for a request. The returned Call has some 81 // augmented behavior for setting responses. 82 func (mt *MockTransactor) OnDo(matchers ...func(*http.Request) bool) *TransactCall { 83 call := mt.On("Do", mock.MatchedBy(func(candidate *http.Request) bool { 84 for _, matcher := range matchers { 85 if !matcher(candidate) { 86 return false 87 } 88 } 89 90 return true 91 })) 92 93 return &TransactCall{call} 94 } 95 96 // OnRoundTrip sets an On("Do", ...) with the given matchers for a request. The returned Call has some 97 // augmented behavior for setting responses. 
98 func (mt *MockTransactor) OnRoundTrip(matchers ...func(*http.Request) bool) *TransactCall { 99 call := mt.On("RoundTrip", mock.MatchedBy(func(candidate *http.Request) bool { 100 for _, matcher := range matchers { 101 if !matcher(candidate) { 102 return false 103 } 104 } 105 106 return true 107 })) 108 109 return &TransactCall{call} 110 } 111 112 // MatchMethod returns a request matcher that verifies each request has a specific method 113 func MatchMethod(expected string) func(*http.Request) bool { 114 return func(r *http.Request) bool { 115 return strings.EqualFold(expected, r.Method) 116 } 117 } 118 119 // MatchURL returns a request matcher that verifies each request has an exact URL. 120 func MatchURL(expected *url.URL) func(*http.Request) bool { 121 return func(r *http.Request) bool { 122 if expected == r.URL { 123 return true 124 } 125 126 if expected == nil || r.URL == nil { 127 return false 128 } 129 130 return *expected == *r.URL 131 } 132 } 133 134 // MatchURLString returns a request matcher that verifies the request's URL translates to the given string. 135 func MatchURLString(expected string) func(*http.Request) bool { 136 return func(r *http.Request) bool { 137 if r.URL == nil { 138 return len(expected) == 0 139 } 140 141 return expected == r.URL.String() 142 } 143 } 144 145 // MatchBody returns a request matcher that verifies each request has an exact body. 146 // The body is consumed, but then replaced so that downstream code can still access the body. 
func MatchBody(expected []byte) func(*http.Request) bool {
	return func(r *http.Request) bool {
		if r.Body == nil {
			return len(expected) == 0
		}

		actual, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// Matchers have no error return, so an unreadable body is reported loudly.
			panic(fmt.Errorf("error while reading request body for matching: %s", err))
		}

		// replace the consumed body so other matchers and test code can reread it
		r.Body = ioutil.NopCloser(bytes.NewReader(actual))

		return bytes.Equal(actual, expected)
	}
}

// MatchBodyString is a convenience for MatchBody that takes the expected body as a string.
func MatchBodyString(expected string) func(*http.Request) bool {
	return MatchBody([]byte(expected))
}

// MatchHeader returns a request matcher that verifies a request carries the given header value.
// The header name is canonicalized before lookup. A request matches when any of the header's
// values equals expected exactly. When the header is absent, including when the request's Header
// map is nil, the request matches only if expected is empty.
func MatchHeader(name, expected string) func(*http.Request) bool {
	key := textproto.CanonicalMIMEHeaderKey(name)
	return func(r *http.Request) bool {
		// Indexing a nil map is safe in Go, so requests built by hand with a nil Header are
		// treated exactly like requests that have no headers at all.
		values := r.Header[key]
		if len(values) == 0 {
			return len(expected) == 0
		}

		for _, actual := range values {
			if actual == expected {
				return true
			}
		}

		return false
	}
}