github.com/YousefHaggyHeroku/pack@v1.5.5/internal/blob/downloader.go (about)

     1  package blob
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"path/filepath"
    13  
    14  	"github.com/mitchellh/ioprogress"
    15  	"github.com/pkg/errors"
    16  
    17  	"github.com/YousefHaggyHeroku/pack/internal/paths"
    18  	"github.com/YousefHaggyHeroku/pack/internal/style"
    19  	"github.com/YousefHaggyHeroku/pack/logging"
    20  )
    21  
    22  const (
    23  	cacheDirPrefix = "c"
    24  	cacheVersion   = "2"
    25  )
    26  
    27  type downloader struct {
    28  	logger       logging.Logger
    29  	baseCacheDir string
    30  }
    31  
    32  func NewDownloader(logger logging.Logger, baseCacheDir string) *downloader { //nolint:golint,gosimple
    33  	return &downloader{
    34  		logger:       logger,
    35  		baseCacheDir: baseCacheDir,
    36  	}
    37  }
    38  
    39  func (d *downloader) Download(ctx context.Context, pathOrURI string) (Blob, error) {
    40  	if paths.IsURI(pathOrURI) {
    41  		parsedURL, err := url.Parse(pathOrURI)
    42  		if err != nil {
    43  			return nil, errors.Wrapf(err, "parsing path/uri %s", style.Symbol(pathOrURI))
    44  		}
    45  
    46  		var path string
    47  		switch parsedURL.Scheme {
    48  		case "file":
    49  			path, err = paths.URIToFilePath(pathOrURI)
    50  		case "http", "https":
    51  			path, err = d.handleHTTP(ctx, pathOrURI)
    52  		default:
    53  			err = fmt.Errorf("unsupported protocol %s in URI %s", style.Symbol(parsedURL.Scheme), style.Symbol(pathOrURI))
    54  		}
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  
    59  		return &blob{path: path}, nil
    60  	}
    61  
    62  	path := d.handleFile(pathOrURI)
    63  
    64  	return &blob{path: path}, nil
    65  }
    66  
    67  func (d *downloader) handleFile(path string) string {
    68  	path, err := filepath.Abs(path)
    69  	if err != nil {
    70  		return ""
    71  	}
    72  
    73  	return path
    74  }
    75  
    76  func (d *downloader) handleHTTP(ctx context.Context, uri string) (string, error) {
    77  	cacheDir := d.versionedCacheDir()
    78  
    79  	if err := os.MkdirAll(cacheDir, 0755); err != nil {
    80  		return "", err
    81  	}
    82  
    83  	cachePath := filepath.Join(cacheDir, fmt.Sprintf("%x", sha256.Sum256([]byte(uri))))
    84  
    85  	etagFile := cachePath + ".etag"
    86  	etagExists, err := fileExists(etagFile)
    87  	if err != nil {
    88  		return "", err
    89  	}
    90  
    91  	etag := ""
    92  	if etagExists {
    93  		bytes, err := ioutil.ReadFile(etagFile)
    94  		if err != nil {
    95  			return "", err
    96  		}
    97  		etag = string(bytes)
    98  	}
    99  
   100  	reader, etag, err := d.downloadAsStream(ctx, uri, etag)
   101  	if err != nil {
   102  		return "", err
   103  	} else if reader == nil {
   104  		return cachePath, nil
   105  	}
   106  	defer reader.Close()
   107  
   108  	fh, err := os.Create(cachePath)
   109  	if err != nil {
   110  		return "", errors.Wrapf(err, "create cache path %s", style.Symbol(cachePath))
   111  	}
   112  	defer fh.Close()
   113  
   114  	_, err = io.Copy(fh, reader)
   115  	if err != nil {
   116  		return "", errors.Wrap(err, "writing cache")
   117  	}
   118  
   119  	if err = ioutil.WriteFile(etagFile, []byte(etag), 0744); err != nil {
   120  		return "", errors.Wrap(err, "writing etag")
   121  	}
   122  
   123  	return cachePath, nil
   124  }
   125  
   126  func (d *downloader) downloadAsStream(ctx context.Context, uri string, etag string) (io.ReadCloser, string, error) {
   127  	req, err := http.NewRequest("GET", uri, nil)
   128  	if err != nil {
   129  		return nil, "", err
   130  	}
   131  	req = req.WithContext(ctx)
   132  
   133  	if etag != "" {
   134  		req.Header.Set("If-None-Match", etag)
   135  	}
   136  
   137  	resp, err := (&http.Client{}).Do(req) //nolint:bodyclose
   138  	if err != nil {
   139  		return nil, "", err
   140  	}
   141  
   142  	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
   143  		d.logger.Infof("Downloading from %s", style.Symbol(uri))
   144  		return withProgress(logging.GetWriterForLevel(d.logger, logging.InfoLevel), resp.Body, resp.ContentLength), resp.Header.Get("Etag"), nil
   145  	}
   146  
   147  	if resp.StatusCode == 304 {
   148  		d.logger.Debugf("Using cached version of %s", style.Symbol(uri))
   149  		return nil, etag, nil
   150  	}
   151  
   152  	return nil, "", fmt.Errorf(
   153  		"could not download from %s, code http status %s",
   154  		style.Symbol(uri), style.SymbolF("%d", resp.StatusCode),
   155  	)
   156  }
   157  
   158  func withProgress(writer io.Writer, rc io.ReadCloser, length int64) io.ReadCloser {
   159  	return &progressReader{
   160  		Closer: rc,
   161  		Reader: &ioprogress.Reader{
   162  			Reader:   rc,
   163  			Size:     length,
   164  			DrawFunc: ioprogress.DrawTerminalf(writer, ioprogress.DrawTextFormatBytes),
   165  		},
   166  	}
   167  }
   168  
   169  type progressReader struct {
   170  	*ioprogress.Reader
   171  	io.Closer
   172  }
   173  
   174  func (d *downloader) versionedCacheDir() string {
   175  	return filepath.Join(d.baseCacheDir, cacheDirPrefix+cacheVersion)
   176  }
   177  
   178  func fileExists(file string) (bool, error) {
   179  	_, err := os.Stat(file)
   180  	if err != nil {
   181  		if os.IsNotExist(err) {
   182  			return false, nil
   183  		}
   184  		return false, err
   185  	}
   186  	return true, nil
   187  }