github.com/aspring/packer@v0.8.1-0.20150629211158-9db281ac0f89/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
type DownloadConfig struct {
	// The source URL in the form of a string. DownloadClient.Get parses it
	// with url.Parse.
	Url string

	// This is the path to download the file to. Get first checks whether a
	// file at this path already matches Checksum; if so, nothing is
	// downloaded.
	TargetPath string

	// DownloaderMap maps a URL scheme (e.g. "http") to the Downloader used
	// for URLs with that scheme. If nil, NewDownloadClient installs one
	// containing HTTPDownloaders for "http" and "https".
	DownloaderMap map[string]Downloader

	// Only relevant for "file" URLs. If true, this will copy even a local
	// file to the target location (which requires a "file" entry in
	// DownloaderMap). If false, then it will "download" the file by just
	// returning the local path to the file.
	CopyFile bool

	// The hashing implementation to use to checksum the downloaded file
	// (HashForType can construct one by name). Get only verifies the
	// result when this is non-nil.
	Hash hash.Hash

	// The checksum for the downloaded file, as raw digest bytes (it is
	// compared directly against Hash.Sum(nil), not as a hex string). The
	// hash implementation configured for the downloader will be used to
	// verify with this checksum after it is downloaded.
	Checksum []byte

	// What to use for the user agent for HTTP requests made by the default
	// HTTP downloaders that NewDownloadClient creates. If set to "", use the
	// default user agent provided by Go.
	UserAgent string
}
    51  
// A DownloadClient helps download, verify checksums, etc. Construct one
// with NewDownloadClient.
type DownloadClient struct {
	// config is the configuration passed to NewDownloadClient.
	config *DownloadConfig

	// downloader is set by Get to the Downloader registered for the URL's
	// scheme. It is nil before that (Get never sets it for local files that
	// are used in place), and PercentProgress reports -1 while it is nil.
	downloader Downloader
}
    57  
    58  // HashForType returns the Hash implementation for the given string
    59  // type, or nil if the type is not supported.
    60  func HashForType(t string) hash.Hash {
    61  	switch t {
    62  	case "md5":
    63  		return md5.New()
    64  	case "sha1":
    65  		return sha1.New()
    66  	case "sha256":
    67  		return sha256.New()
    68  	case "sha512":
    69  		return sha512.New()
    70  	default:
    71  		return nil
    72  	}
    73  }
    74  
    75  // NewDownloadClient returns a new DownloadClient for the given
    76  // configuration.
    77  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    78  	if c.DownloaderMap == nil {
    79  		c.DownloaderMap = map[string]Downloader{
    80  			"http":  &HTTPDownloader{userAgent: c.UserAgent},
    81  			"https": &HTTPDownloader{userAgent: c.UserAgent},
    82  		}
    83  	}
    84  
    85  	return &DownloadClient{config: c}
    86  }
    87  
// Downloader is responsible for actually taking a remote URL and
// downloading it. DownloadClient picks one per URL scheme from
// DownloadConfig.DownloaderMap.
type Downloader interface {
	// Cancel is intended to abort an in-progress download. NOTE: the
	// HTTPDownloader in this package currently implements it as a no-op.
	Cancel()

	// Download fetches the URL into the file. DownloadClient.Get passes a
	// file opened read/write (created if missing, not truncated) and closes
	// it after Download returns.
	Download(*os.File, *url.URL) error

	// Progress reports how much has been transferred so far and Total the
	// expected final amount. PercentProgress computes Progress/Total. The
	// bundled HTTPDownloader counts both in bytes.
	Progress() uint
	Total() uint
}
    96  
// Cancel is intended to abort an in-progress Get. It is not implemented
// yet and currently does nothing. In particular, it does not call Cancel
// on the active Downloader.
func (d *DownloadClient) Cancel() {
	// TODO(mitchellh): Implement
}
   100  
   101  func (d *DownloadClient) Get() (string, error) {
   102  	// If we already have the file and it matches, then just return the target path.
   103  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   104  		log.Println("Initial checksum matched, no download needed.")
   105  		return d.config.TargetPath, nil
   106  	}
   107  
   108  	url, err := url.Parse(d.config.Url)
   109  	if err != nil {
   110  		return "", err
   111  	}
   112  
   113  	log.Printf("Parsed URL: %#v", url)
   114  
   115  	// Files when we don't copy the file are special cased.
   116  	var f *os.File
   117  	var finalPath string
   118  	if url.Scheme == "file" && !d.config.CopyFile {
   119  		finalPath = url.Path
   120  
   121  		// Remove forward slash on absolute Windows file URLs before processing
   122  		if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' {
   123  			finalPath = finalPath[1:len(finalPath)]
   124  		}
   125  	} else {
   126  		finalPath = d.config.TargetPath
   127  
   128  		var ok bool
   129  		d.downloader, ok = d.config.DownloaderMap[url.Scheme]
   130  		if !ok {
   131  			return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme)
   132  		}
   133  
   134  		// Otherwise, download using the downloader.
   135  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   136  		if err != nil {
   137  			return "", err
   138  		}
   139  
   140  		log.Printf("Downloading: %s", url.String())
   141  		err = d.downloader.Download(f, url)
   142  		f.Close()
   143  		if err != nil {
   144  			return "", err
   145  		}
   146  	}
   147  
   148  	if d.config.Hash != nil {
   149  		var verify bool
   150  		verify, err = d.VerifyChecksum(finalPath)
   151  		if err == nil && !verify {
   152  			// Delete the file
   153  			os.Remove(finalPath)
   154  
   155  			err = fmt.Errorf(
   156  				"checksums didn't match expected: %s",
   157  				hex.EncodeToString(d.config.Checksum))
   158  		}
   159  	}
   160  
   161  	return finalPath, err
   162  }
   163  
   164  // PercentProgress returns the download progress as a percentage.
   165  func (d *DownloadClient) PercentProgress() int {
   166  	if d.downloader == nil {
   167  		return -1
   168  	}
   169  
   170  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   171  }
   172  
   173  // VerifyChecksum tests that the path matches the checksum for the
   174  // download.
   175  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   176  	if d.config.Checksum == nil || d.config.Hash == nil {
   177  		return false, errors.New("Checksum or Hash isn't set on download.")
   178  	}
   179  
   180  	f, err := os.Open(path)
   181  	if err != nil {
   182  		return false, err
   183  	}
   184  	defer f.Close()
   185  
   186  	log.Printf("Verifying checksum of %s", path)
   187  	d.config.Hash.Reset()
   188  	io.Copy(d.config.Hash, f)
   189  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   190  }
   191  
   192  // HTTPDownloader is an implementation of Downloader that downloads
   193  // files over HTTP.
   194  type HTTPDownloader struct {
   195  	progress  uint
   196  	total     uint
   197  	userAgent string
   198  }
   199  
   200  func (*HTTPDownloader) Cancel() {
   201  	// TODO(mitchellh): Implement
   202  }
   203  
   204  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   205  	log.Printf("Starting download: %s", src.String())
   206  
   207  	// Seek to the beginning by default
   208  	if _, err := dst.Seek(0, 0); err != nil {
   209  		return err
   210  	}
   211  
   212  	// Reset our progress
   213  	d.progress = 0
   214  
   215  	// Make the request. We first make a HEAD request so we can check
   216  	// if the server supports range queries. If the server/URL doesn't
   217  	// support HEAD requests, we just fall back to GET.
   218  	req, err := http.NewRequest("HEAD", src.String(), nil)
   219  	if err != nil {
   220  		return err
   221  	}
   222  
   223  	if d.userAgent != "" {
   224  		req.Header.Set("User-Agent", d.userAgent)
   225  	}
   226  
   227  	httpClient := &http.Client{
   228  		Transport: &http.Transport{
   229  			Proxy: http.ProxyFromEnvironment,
   230  		},
   231  	}
   232  
   233  	resp, err := httpClient.Do(req)
   234  	if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
   235  		// If the HEAD request succeeded, then attempt to set the range
   236  		// query if we can.
   237  		if resp.Header.Get("Accept-Ranges") == "bytes" {
   238  			if fi, err := dst.Stat(); err == nil {
   239  				if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   240  					req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   241  					d.progress = uint(fi.Size())
   242  				}
   243  			}
   244  		}
   245  	}
   246  
   247  	// Set the request to GET now, and redo the query to download
   248  	req.Method = "GET"
   249  
   250  	resp, err = httpClient.Do(req)
   251  	if err != nil {
   252  		return err
   253  	}
   254  
   255  	d.total = d.progress + uint(resp.ContentLength)
   256  	var buffer [4096]byte
   257  	for {
   258  		n, err := resp.Body.Read(buffer[:])
   259  		if err != nil && err != io.EOF {
   260  			return err
   261  		}
   262  
   263  		d.progress += uint(n)
   264  
   265  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   266  			return werr
   267  		}
   268  
   269  		if err == io.EOF {
   270  			break
   271  		}
   272  	}
   273  
   274  	return nil
   275  }
   276  
   277  func (d *HTTPDownloader) Progress() uint {
   278  	return d.progress
   279  }
   280  
   281  func (d *HTTPDownloader) Total() uint {
   282  	return d.total
   283  }