go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/lhttp/client_test.go (about) 1 // Copyright 2015 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package lhttp 16 17 import ( 18 "bytes" 19 "context" 20 "fmt" 21 "io" 22 "net/http" 23 "net/http/httptest" 24 "sync" 25 "testing" 26 27 "go.chromium.org/luci/common/retry" 28 29 . "github.com/smartystreets/goconvey/convey" 30 ) 31 32 func httpReqGen(method, url string, body []byte) RequestGen { 33 return func() (*http.Request, error) { 34 var bodyReader io.Reader 35 if body != nil { 36 bodyReader = bytes.NewReader(body) 37 } 38 return http.NewRequest("GET", url, bodyReader) 39 } 40 } 41 42 func TestNewRequestGET(t *testing.T) { 43 Convey(`HTTP GET requests should be handled correctly.`, t, func(c C) { 44 ctx := context.Background() 45 46 // First call returns HTTP 500, second succeeds. 47 serverCalls := 0 48 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 serverCalls++ 50 content, err := io.ReadAll(r.Body) 51 c.So(err, ShouldBeNil) 52 c.So(content, ShouldResemble, []byte{}) 53 if serverCalls == 1 { 54 w.WriteHeader(500) 55 } else { 56 fmt.Fprintf(w, "Hello, client\n") 57 } 58 })) 59 defer ts.Close() 60 61 httpReq := httpReqGen("GET", ts.URL, nil) 62 63 clientCalls := 0 64 clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error { 65 clientCalls++ 66 content, err := io.ReadAll(resp.Body) 67 So(err, ShouldBeNil) 68 So(string(content), ShouldResemble, "Hello, client\n") 69 So(resp.Body.Close(), ShouldBeNil) 70 return nil 71 }, nil) 72 73 status, err := clientReq() 74 So(err, ShouldBeNil) 75 So(status, ShouldResemble, 200) 76 So(serverCalls, ShouldResemble, 2) 77 So(clientCalls, ShouldResemble, 1) 78 }) 79 } 80 81 func TestNewRequestPOST(t *testing.T) { 82 Convey(`HTTP POST requests should be handled correctly.`, t, func(c C) { 83 ctx := context.Background() 84 85 // First call returns HTTP 500, second succeeds. 86 serverCalls := 0 87 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 88 serverCalls++ 89 content, err := io.ReadAll(r.Body) 90 c.So(err, ShouldBeNil) 91 // The same data is sent twice. 92 c.So(string(content), ShouldResemble, "foo bar") 93 if serverCalls == 1 { 94 w.WriteHeader(500) 95 } else { 96 fmt.Fprintf(w, "Hello, client\n") 97 } 98 })) 99 defer ts.Close() 100 101 httpReq := httpReqGen("POST", ts.URL, []byte("foo bar")) 102 103 clientCalls := 0 104 clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error { 105 clientCalls++ 106 content, err := io.ReadAll(resp.Body) 107 So(err, ShouldBeNil) 108 So(string(content), ShouldResemble, "Hello, client\n") 109 So(resp.Body.Close(), ShouldBeNil) 110 return nil 111 }, nil) 112 113 status, err := clientReq() 114 So(err, ShouldBeNil) 115 So(status, ShouldResemble, 200) 116 So(serverCalls, ShouldResemble, 2) 117 So(clientCalls, ShouldResemble, 1) 118 }) 119 } 120 121 func TestNewRequestGETFail(t *testing.T) { 122 Convey(`HTTP GET requests should handle failure successfully.`, t, func() { 123 ctx := context.Background() 124 125 serverCalls := 0 126 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 127 serverCalls++ 128 w.WriteHeader(500) 129 })) 130 defer ts.Close() 131 132 httpReq := httpReqGen("GET", ts.URL, nil) 133 134 clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error { 135 t.Fail() 136 return nil 137 }, nil) 138 139 status, err := clientReq() 140 So(err.Error(), ShouldResemble, "gave up after 4 attempts: http request failed: Internal Server Error (HTTP 500)") 141 So(status, ShouldResemble, 500) 142 }) 143 } 144 145 func TestNewRequestDefaultFactory(t *testing.T) { 146 // Test that the default factory (rFn == nil) only retries for transient 147 // HTTP errors. 148 testCases := []struct { 149 statusCode int // The status code to return (the first 2 times). 150 path string // Request path, if any. 151 wantErr bool // Whether we want NewRequest to return an error. 152 wantCalls int // The total number of HTTP requests expected. 153 }{ 154 // 200, passes immediately. 155 {statusCode: 200, wantErr: false, wantCalls: 1}, 156 // Transient HTTP error codes that will retry. 157 {statusCode: 408, wantErr: false, wantCalls: 3}, 158 {statusCode: 500, wantErr: false, wantCalls: 3}, 159 {statusCode: 503, wantErr: false, wantCalls: 3}, 160 // Immediate failure codes. 161 {statusCode: 403, wantErr: true, wantCalls: 1}, 162 {statusCode: 404, wantErr: true, wantCalls: 1}, 163 } 164 165 ctx := context.Background() 166 167 for _, tc := range testCases { 168 tc := tc 169 t.Run(fmt.Sprintf("Status code %d, path %q", tc.statusCode, tc.path), func(t *testing.T) { 170 t.Parallel() 171 serverCalls := 0 172 ts := httptest.NewServer( 173 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 174 defer r.Body.Close() 175 serverCalls++ 176 if serverCalls <= 2 { 177 w.WriteHeader(tc.statusCode) 178 } 179 fmt.Fprintf(w, "Hello World!\n") 180 })) 181 defer ts.Close() 182 183 httpReq := httpReqGen("GET", ts.URL+tc.path, nil) 184 req := NewRequest(ctx, http.DefaultClient, nil, httpReq, func(resp *http.Response) error { 185 return resp.Body.Close() 186 }, nil) 187 188 _, err := req() 189 if err == nil && tc.wantErr { 190 t.Error("req returned nil error, wanted an error") 191 } else if err != nil && !tc.wantErr { 192 t.Errorf("req returned err %v, wanted nil", err) 193 } 194 if got, want := serverCalls, tc.wantCalls; got != want { 195 t.Errorf("total server calls; got %d, want %d", got, want) 196 } 197 }) 198 } 199 } 200 201 func TestNewRequestClosesBody(t *testing.T) { 202 ctx := context.Background() 203 serverCalls := 0 204 205 // Return a 500 for the first 2 requests. 206 ts := httptest.NewServer( 207 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 208 defer r.Body.Close() 209 serverCalls++ 210 if serverCalls <= 2 { 211 w.WriteHeader(500) 212 } 213 fmt.Fprintf(w, "Hello World!\n") 214 })) 215 defer ts.Close() 216 217 rt := &trackingRoundTripper{RoundTripper: http.DefaultTransport} 218 hc := &http.Client{Transport: rt} 219 httpReq := httpReqGen("GET", ts.URL, nil) 220 221 clientCalls := 0 222 var lastResp *http.Response 223 req := NewRequest(ctx, hc, fast, httpReq, func(resp *http.Response) error { 224 clientCalls++ 225 lastResp = resp 226 return resp.Body.Close() 227 }, nil) 228 229 status, err := req() 230 if err != nil { 231 t.Fatalf("req returned err %v, want nil", err) 232 } 233 if got, want := status, http.StatusOK; got != want { 234 t.Errorf("req returned status %d, want %d", got, want) 235 } 236 237 // We expect only one client call, but three requests through to the server. 238 if got, want := clientCalls, 1; got != want { 239 t.Errorf("handler callback invoked %d times, want %d", got, want) 240 } 241 if got, want := len(rt.Responses), 3; got != want { 242 t.Errorf("len(Responses) = %d, want %d", got, want) 243 } 244 245 // Check that the last response is the one we handled, and that all the bodies 246 // were closed. 247 if got, want := lastResp, rt.Responses[2]; got != want { 248 t.Errorf("Last Response did not match Response in handler callback.\nGot: %v\nWant: %v", got, want) 249 } 250 for i, resp := range rt.Responses { 251 rc := resp.Body.(*trackingReadCloser) 252 if !rc.Closed { 253 t.Errorf("Responses[%d].Body was not closed", i) 254 } 255 } 256 } 257 258 // trackingRoundTripper wraps an http.RoundTripper, keeping track of any 259 // returned Responses. Each response's Body, when set, is wrapped with a 260 // trackingReadCloser. 261 type trackingRoundTripper struct { 262 http.RoundTripper 263 264 mu sync.Mutex 265 Responses []*http.Response 266 } 267 268 func (t *trackingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 269 resp, err := t.RoundTripper.RoundTrip(req) 270 if resp != nil && resp.Body != nil { 271 resp.Body = &trackingReadCloser{ReadCloser: resp.Body} 272 } 273 t.mu.Lock() 274 defer t.mu.Unlock() 275 t.Responses = append(t.Responses, resp) 276 return resp, err 277 } 278 279 // trackingReadCloser wraps an io.ReadCloser, keeping track of whether Closed was 280 // called. 281 type trackingReadCloser struct { 282 io.ReadCloser 283 Closed bool 284 } 285 286 func (t *trackingReadCloser) Close() error { 287 t.Closed = true 288 return t.ReadCloser.Close() 289 } 290 291 // Private details. 292 293 func fast() retry.Iterator { 294 return &retry.Limited{Retries: 3} 295 }