go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/lhttp/client_test.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lhttp
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"sync"
    25  	"testing"
    26  
    27  	"go.chromium.org/luci/common/retry"
    28  
    29  	. "github.com/smartystreets/goconvey/convey"
    30  )
    31  
    32  func httpReqGen(method, url string, body []byte) RequestGen {
    33  	return func() (*http.Request, error) {
    34  		var bodyReader io.Reader
    35  		if body != nil {
    36  			bodyReader = bytes.NewReader(body)
    37  		}
    38  		return http.NewRequest("GET", url, bodyReader)
    39  	}
    40  }
    41  
    42  func TestNewRequestGET(t *testing.T) {
    43  	Convey(`HTTP GET requests should be handled correctly.`, t, func(c C) {
    44  		ctx := context.Background()
    45  
    46  		// First call returns HTTP 500, second succeeds.
    47  		serverCalls := 0
    48  		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    49  			serverCalls++
    50  			content, err := io.ReadAll(r.Body)
    51  			c.So(err, ShouldBeNil)
    52  			c.So(content, ShouldResemble, []byte{})
    53  			if serverCalls == 1 {
    54  				w.WriteHeader(500)
    55  			} else {
    56  				fmt.Fprintf(w, "Hello, client\n")
    57  			}
    58  		}))
    59  		defer ts.Close()
    60  
    61  		httpReq := httpReqGen("GET", ts.URL, nil)
    62  
    63  		clientCalls := 0
    64  		clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error {
    65  			clientCalls++
    66  			content, err := io.ReadAll(resp.Body)
    67  			So(err, ShouldBeNil)
    68  			So(string(content), ShouldResemble, "Hello, client\n")
    69  			So(resp.Body.Close(), ShouldBeNil)
    70  			return nil
    71  		}, nil)
    72  
    73  		status, err := clientReq()
    74  		So(err, ShouldBeNil)
    75  		So(status, ShouldResemble, 200)
    76  		So(serverCalls, ShouldResemble, 2)
    77  		So(clientCalls, ShouldResemble, 1)
    78  	})
    79  }
    80  
    81  func TestNewRequestPOST(t *testing.T) {
    82  	Convey(`HTTP POST requests should be handled correctly.`, t, func(c C) {
    83  		ctx := context.Background()
    84  
    85  		// First call returns HTTP 500, second succeeds.
    86  		serverCalls := 0
    87  		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    88  			serverCalls++
    89  			content, err := io.ReadAll(r.Body)
    90  			c.So(err, ShouldBeNil)
    91  			// The same data is sent twice.
    92  			c.So(string(content), ShouldResemble, "foo bar")
    93  			if serverCalls == 1 {
    94  				w.WriteHeader(500)
    95  			} else {
    96  				fmt.Fprintf(w, "Hello, client\n")
    97  			}
    98  		}))
    99  		defer ts.Close()
   100  
   101  		httpReq := httpReqGen("POST", ts.URL, []byte("foo bar"))
   102  
   103  		clientCalls := 0
   104  		clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error {
   105  			clientCalls++
   106  			content, err := io.ReadAll(resp.Body)
   107  			So(err, ShouldBeNil)
   108  			So(string(content), ShouldResemble, "Hello, client\n")
   109  			So(resp.Body.Close(), ShouldBeNil)
   110  			return nil
   111  		}, nil)
   112  
   113  		status, err := clientReq()
   114  		So(err, ShouldBeNil)
   115  		So(status, ShouldResemble, 200)
   116  		So(serverCalls, ShouldResemble, 2)
   117  		So(clientCalls, ShouldResemble, 1)
   118  	})
   119  }
   120  
   121  func TestNewRequestGETFail(t *testing.T) {
   122  	Convey(`HTTP GET requests should handle failure successfully.`, t, func() {
   123  		ctx := context.Background()
   124  
   125  		serverCalls := 0
   126  		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   127  			serverCalls++
   128  			w.WriteHeader(500)
   129  		}))
   130  		defer ts.Close()
   131  
   132  		httpReq := httpReqGen("GET", ts.URL, nil)
   133  
   134  		clientReq := NewRequest(ctx, http.DefaultClient, fast, httpReq, func(resp *http.Response) error {
   135  			t.Fail()
   136  			return nil
   137  		}, nil)
   138  
   139  		status, err := clientReq()
   140  		So(err.Error(), ShouldResemble, "gave up after 4 attempts: http request failed: Internal Server Error (HTTP 500)")
   141  		So(status, ShouldResemble, 500)
   142  	})
   143  }
   144  
   145  func TestNewRequestDefaultFactory(t *testing.T) {
   146  	// Test that the default factory (rFn == nil) only retries for transient
   147  	// HTTP errors.
   148  	testCases := []struct {
   149  		statusCode int    // The status code to return (the first 2 times).
   150  		path       string // Request path, if any.
   151  		wantErr    bool   // Whether we want NewRequest to return an error.
   152  		wantCalls  int    // The total number of HTTP requests expected.
   153  	}{
   154  		// 200, passes immediately.
   155  		{statusCode: 200, wantErr: false, wantCalls: 1},
   156  		// Transient HTTP error codes that will retry.
   157  		{statusCode: 408, wantErr: false, wantCalls: 3},
   158  		{statusCode: 500, wantErr: false, wantCalls: 3},
   159  		{statusCode: 503, wantErr: false, wantCalls: 3},
   160  		// Immediate failure codes.
   161  		{statusCode: 403, wantErr: true, wantCalls: 1},
   162  		{statusCode: 404, wantErr: true, wantCalls: 1},
   163  	}
   164  
   165  	ctx := context.Background()
   166  
   167  	for _, tc := range testCases {
   168  		tc := tc
   169  		t.Run(fmt.Sprintf("Status code %d, path %q", tc.statusCode, tc.path), func(t *testing.T) {
   170  			t.Parallel()
   171  			serverCalls := 0
   172  			ts := httptest.NewServer(
   173  				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   174  					defer r.Body.Close()
   175  					serverCalls++
   176  					if serverCalls <= 2 {
   177  						w.WriteHeader(tc.statusCode)
   178  					}
   179  					fmt.Fprintf(w, "Hello World!\n")
   180  				}))
   181  			defer ts.Close()
   182  
   183  			httpReq := httpReqGen("GET", ts.URL+tc.path, nil)
   184  			req := NewRequest(ctx, http.DefaultClient, nil, httpReq, func(resp *http.Response) error {
   185  				return resp.Body.Close()
   186  			}, nil)
   187  
   188  			_, err := req()
   189  			if err == nil && tc.wantErr {
   190  				t.Error("req returned nil error, wanted an error")
   191  			} else if err != nil && !tc.wantErr {
   192  				t.Errorf("req returned err %v, wanted nil", err)
   193  			}
   194  			if got, want := serverCalls, tc.wantCalls; got != want {
   195  				t.Errorf("total server calls; got %d, want %d", got, want)
   196  			}
   197  		})
   198  	}
   199  }
   200  
   201  func TestNewRequestClosesBody(t *testing.T) {
   202  	ctx := context.Background()
   203  	serverCalls := 0
   204  
   205  	// Return a 500 for the first 2 requests.
   206  	ts := httptest.NewServer(
   207  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   208  			defer r.Body.Close()
   209  			serverCalls++
   210  			if serverCalls <= 2 {
   211  				w.WriteHeader(500)
   212  			}
   213  			fmt.Fprintf(w, "Hello World!\n")
   214  		}))
   215  	defer ts.Close()
   216  
   217  	rt := &trackingRoundTripper{RoundTripper: http.DefaultTransport}
   218  	hc := &http.Client{Transport: rt}
   219  	httpReq := httpReqGen("GET", ts.URL, nil)
   220  
   221  	clientCalls := 0
   222  	var lastResp *http.Response
   223  	req := NewRequest(ctx, hc, fast, httpReq, func(resp *http.Response) error {
   224  		clientCalls++
   225  		lastResp = resp
   226  		return resp.Body.Close()
   227  	}, nil)
   228  
   229  	status, err := req()
   230  	if err != nil {
   231  		t.Fatalf("req returned err %v, want nil", err)
   232  	}
   233  	if got, want := status, http.StatusOK; got != want {
   234  		t.Errorf("req returned status %d, want %d", got, want)
   235  	}
   236  
   237  	// We expect only one client call, but three requests through to the server.
   238  	if got, want := clientCalls, 1; got != want {
   239  		t.Errorf("handler callback invoked %d times, want %d", got, want)
   240  	}
   241  	if got, want := len(rt.Responses), 3; got != want {
   242  		t.Errorf("len(Responses) = %d, want %d", got, want)
   243  	}
   244  
   245  	// Check that the last response is the one we handled, and that all the bodies
   246  	// were closed.
   247  	if got, want := lastResp, rt.Responses[2]; got != want {
   248  		t.Errorf("Last Response did not match Response in handler callback.\nGot:  %v\nWant: %v", got, want)
   249  	}
   250  	for i, resp := range rt.Responses {
   251  		rc := resp.Body.(*trackingReadCloser)
   252  		if !rc.Closed {
   253  			t.Errorf("Responses[%d].Body was not closed", i)
   254  		}
   255  	}
   256  }
   257  
   258  // trackingRoundTripper wraps an http.RoundTripper, keeping track of any
   259  // returned Responses. Each response's Body, when set, is wrapped with a
   260  // trackingReadCloser.
   261  type trackingRoundTripper struct {
   262  	http.RoundTripper
   263  
   264  	mu        sync.Mutex
   265  	Responses []*http.Response
   266  }
   267  
   268  func (t *trackingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   269  	resp, err := t.RoundTripper.RoundTrip(req)
   270  	if resp != nil && resp.Body != nil {
   271  		resp.Body = &trackingReadCloser{ReadCloser: resp.Body}
   272  	}
   273  	t.mu.Lock()
   274  	defer t.mu.Unlock()
   275  	t.Responses = append(t.Responses, resp)
   276  	return resp, err
   277  }
   278  
   279  // trackingReadCloser wraps an io.ReadCloser, keeping track of whether Closed was
   280  // called.
   281  type trackingReadCloser struct {
   282  	io.ReadCloser
   283  	Closed bool
   284  }
   285  
   286  func (t *trackingReadCloser) Close() error {
   287  	t.Closed = true
   288  	return t.ReadCloser.Close()
   289  }
   290  
   291  // Private details.
   292  
   293  func fast() retry.Iterator {
   294  	return &retry.Limited{Retries: 3}
   295  }