github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/api/cache.go (about)

     1  package api
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/sha256"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"os"
    13  	"path/filepath"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  )
    18  
    19  func NewCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
    20  	cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
    21  	return &http.Client{
    22  		Transport: CacheResponse(cacheTTL, cacheDir)(httpClient.Transport),
    23  	}
    24  }
    25  
    26  func isCacheableRequest(req *http.Request) bool {
    27  	if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") {
    28  		return true
    29  	}
    30  
    31  	if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") {
    32  		return true
    33  	}
    34  
    35  	return false
    36  }
    37  
    38  func isCacheableResponse(res *http.Response) bool {
    39  	return res.StatusCode < 500 && res.StatusCode != 403
    40  }
    41  
    42  // CacheResponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
    43  func CacheResponse(ttl time.Duration, dir string) ClientOption {
    44  	fs := fileStorage{
    45  		dir: dir,
    46  		ttl: ttl,
    47  		mu:  &sync.RWMutex{},
    48  	}
    49  
    50  	return func(tr http.RoundTripper) http.RoundTripper {
    51  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    52  			if !isCacheableRequest(req) {
    53  				return tr.RoundTrip(req)
    54  			}
    55  
    56  			key, keyErr := cacheKey(req)
    57  			if keyErr == nil {
    58  				if res, err := fs.read(key); err == nil {
    59  					res.Request = req
    60  					return res, nil
    61  				}
    62  			}
    63  
    64  			res, err := tr.RoundTrip(req)
    65  			if err == nil && keyErr == nil && isCacheableResponse(res) {
    66  				_ = fs.store(key, res)
    67  			}
    68  			return res, err
    69  		}}
    70  	}
    71  }
    72  
    73  func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
    74  	b := &bytes.Buffer{}
    75  	nr := io.TeeReader(r, b)
    76  	return ioutil.NopCloser(b), &readCloser{
    77  		Reader: nr,
    78  		Closer: r,
    79  	}
    80  }
    81  
    82  type readCloser struct {
    83  	io.Reader
    84  	io.Closer
    85  }
    86  
    87  func cacheKey(req *http.Request) (string, error) {
    88  	h := sha256.New()
    89  	fmt.Fprintf(h, "%s:", req.Method)
    90  	fmt.Fprintf(h, "%s:", req.URL.String())
    91  	fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
    92  	fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
    93  
    94  	if req.Body != nil {
    95  		var bodyCopy io.ReadCloser
    96  		req.Body, bodyCopy = copyStream(req.Body)
    97  		defer bodyCopy.Close()
    98  		if _, err := io.Copy(h, bodyCopy); err != nil {
    99  			return "", err
   100  		}
   101  	}
   102  
   103  	digest := h.Sum(nil)
   104  	return fmt.Sprintf("%x", digest), nil
   105  }
   106  
   107  type fileStorage struct {
   108  	dir string
   109  	ttl time.Duration
   110  	mu  *sync.RWMutex
   111  }
   112  
   113  func (fs *fileStorage) filePath(key string) string {
   114  	if len(key) >= 6 {
   115  		return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
   116  	}
   117  	return filepath.Join(fs.dir, key)
   118  }
   119  
   120  func (fs *fileStorage) read(key string) (*http.Response, error) {
   121  	cacheFile := fs.filePath(key)
   122  
   123  	fs.mu.RLock()
   124  	defer fs.mu.RUnlock()
   125  
   126  	f, err := os.Open(cacheFile)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	defer f.Close()
   131  
   132  	stat, err := f.Stat()
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	age := time.Since(stat.ModTime())
   138  	if age > fs.ttl {
   139  		return nil, errors.New("cache expired")
   140  	}
   141  
   142  	body := &bytes.Buffer{}
   143  	_, err = io.Copy(body, f)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	res, err := http.ReadResponse(bufio.NewReader(body), nil)
   149  	return res, err
   150  }
   151  
   152  func (fs *fileStorage) store(key string, res *http.Response) error {
   153  	cacheFile := fs.filePath(key)
   154  
   155  	fs.mu.Lock()
   156  	defer fs.mu.Unlock()
   157  
   158  	err := os.MkdirAll(filepath.Dir(cacheFile), 0755)
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	defer f.Close()
   168  
   169  	var origBody io.ReadCloser
   170  	if res.Body != nil {
   171  		origBody, res.Body = copyStream(res.Body)
   172  		defer res.Body.Close()
   173  	}
   174  	err = res.Write(f)
   175  	if origBody != nil {
   176  		res.Body = origBody
   177  	}
   178  	return err
   179  }