github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/api/cache_test.go (about) 1 package api 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "net/http" 9 "path/filepath" 10 "testing" 11 "time" 12 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15 ) 16 17 func Test_CacheResponse(t *testing.T) { 18 counter := 0 19 fakeHTTP := funcTripper{ 20 roundTrip: func(req *http.Request) (*http.Response, error) { 21 counter += 1 22 body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) 23 status := 200 24 if req.URL.Path == "/error" { 25 status = 500 26 } 27 return &http.Response{ 28 StatusCode: status, 29 Body: ioutil.NopCloser(bytes.NewBufferString(body)), 30 }, nil 31 }, 32 } 33 34 cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") 35 httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheResponse(time.Minute, cacheDir)) 36 37 do := func(method, url string, body io.Reader) (string, error) { 38 req, err := http.NewRequest(method, url, body) 39 if err != nil { 40 return "", err 41 } 42 res, err := httpClient.Do(req) 43 if err != nil { 44 return "", err 45 } 46 defer res.Body.Close() 47 resBody, err := ioutil.ReadAll(res.Body) 48 if err != nil { 49 err = fmt.Errorf("ReadAll: %w", err) 50 } 51 return string(resBody), err 52 } 53 54 var res string 55 var err error 56 57 res, err = do("GET", "http://example.com/path", nil) 58 require.NoError(t, err) 59 assert.Equal(t, "1: GET http://example.com/path", res) 60 res, err = do("GET", "http://example.com/path", nil) 61 require.NoError(t, err) 62 assert.Equal(t, "1: GET http://example.com/path", res) 63 64 res, err = do("GET", "http://example.com/path2", nil) 65 require.NoError(t, err) 66 assert.Equal(t, "2: GET http://example.com/path2", res) 67 68 res, err = do("POST", "http://example.com/path2", nil) 69 require.NoError(t, err) 70 assert.Equal(t, "3: POST http://example.com/path2", res) 71 72 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 73 require.NoError(t, err) 74 assert.Equal(t, "4: POST http://example.com/graphql", res) 75 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 76 require.NoError(t, err) 77 assert.Equal(t, "4: POST http://example.com/graphql", res) 78 79 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) 80 require.NoError(t, err) 81 assert.Equal(t, "5: POST http://example.com/graphql", res) 82 83 res, err = do("GET", "http://example.com/error", nil) 84 require.NoError(t, err) 85 assert.Equal(t, "6: GET http://example.com/error", res) 86 res, err = do("GET", "http://example.com/error", nil) 87 require.NoError(t, err) 88 assert.Equal(t, "7: GET http://example.com/error", res) 89 }