github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/api/cache_test.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"path/filepath"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func Test_CacheResponse(t *testing.T) {
    18  	counter := 0
    19  	fakeHTTP := funcTripper{
    20  		roundTrip: func(req *http.Request) (*http.Response, error) {
    21  			counter += 1
    22  			body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
    23  			status := 200
    24  			if req.URL.Path == "/error" {
    25  				status = 500
    26  			}
    27  			return &http.Response{
    28  				StatusCode: status,
    29  				Body:       ioutil.NopCloser(bytes.NewBufferString(body)),
    30  			}, nil
    31  		},
    32  	}
    33  
    34  	cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
    35  	httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheResponse(time.Minute, cacheDir))
    36  
    37  	do := func(method, url string, body io.Reader) (string, error) {
    38  		req, err := http.NewRequest(method, url, body)
    39  		if err != nil {
    40  			return "", err
    41  		}
    42  		res, err := httpClient.Do(req)
    43  		if err != nil {
    44  			return "", err
    45  		}
    46  		defer res.Body.Close()
    47  		resBody, err := ioutil.ReadAll(res.Body)
    48  		if err != nil {
    49  			err = fmt.Errorf("ReadAll: %w", err)
    50  		}
    51  		return string(resBody), err
    52  	}
    53  
    54  	var res string
    55  	var err error
    56  
    57  	res, err = do("GET", "http://example.com/path", nil)
    58  	require.NoError(t, err)
    59  	assert.Equal(t, "1: GET http://example.com/path", res)
    60  	res, err = do("GET", "http://example.com/path", nil)
    61  	require.NoError(t, err)
    62  	assert.Equal(t, "1: GET http://example.com/path", res)
    63  
    64  	res, err = do("GET", "http://example.com/path2", nil)
    65  	require.NoError(t, err)
    66  	assert.Equal(t, "2: GET http://example.com/path2", res)
    67  
    68  	res, err = do("POST", "http://example.com/path2", nil)
    69  	require.NoError(t, err)
    70  	assert.Equal(t, "3: POST http://example.com/path2", res)
    71  
    72  	res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
    73  	require.NoError(t, err)
    74  	assert.Equal(t, "4: POST http://example.com/graphql", res)
    75  	res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
    76  	require.NoError(t, err)
    77  	assert.Equal(t, "4: POST http://example.com/graphql", res)
    78  
    79  	res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
    80  	require.NoError(t, err)
    81  	assert.Equal(t, "5: POST http://example.com/graphql", res)
    82  
    83  	res, err = do("GET", "http://example.com/error", nil)
    84  	require.NoError(t, err)
    85  	assert.Equal(t, "6: GET http://example.com/error", res)
    86  	res, err = do("GET", "http://example.com/error", nil)
    87  	require.NoError(t, err)
    88  	assert.Equal(t, "7: GET http://example.com/error", res)
    89  }