github.com/supr/packer@v0.3.10-0.20131015195147-7b09e24ac3c1/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
    21  // DownloadConfig is the configuration given to instantiate a new
    22  // download instance. Once a configuration is used to instantiate
    23  // a download client, it must not be modified.
    24  type DownloadConfig struct {
    25  	// The source URL in the form of a string.
    26  	Url string
    27  
    28  	// This is the path to download the file to.
    29  	TargetPath string
    30  
    31  	// DownloaderMap maps a schema to a Download.
    32  	DownloaderMap map[string]Downloader
    33  
    34  	// If true, this will copy even a local file to the target
    35  	// location. If false, then it will "download" the file by just
    36  	// returning the local path to the file.
    37  	CopyFile bool
    38  
    39  	// The hashing implementation to use to checksum the downloaded file.
    40  	Hash hash.Hash
    41  
    42  	// The checksum for the downloaded file. The hash implementation configuration
    43  	// for the downloader will be used to verify with this checksum after
    44  	// it is downloaded.
    45  	Checksum []byte
    46  }
    47  
    48  // A DownloadClient helps download, verify checksums, etc.
    49  type DownloadClient struct {
    50  	config     *DownloadConfig
    51  	downloader Downloader
    52  }
    53  
    54  // HashForType returns the Hash implementation for the given string
    55  // type, or nil if the type is not supported.
    56  func HashForType(t string) hash.Hash {
    57  	switch t {
    58  	case "md5":
    59  		return md5.New()
    60  	case "sha1":
    61  		return sha1.New()
    62  	case "sha256":
    63  		return sha256.New()
    64  	case "sha512":
    65  		return sha512.New()
    66  	default:
    67  		return nil
    68  	}
    69  }
    70  
    71  // NewDownloadClient returns a new DownloadClient for the given
    72  // configuration.
    73  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    74  	if c.DownloaderMap == nil {
    75  		c.DownloaderMap = map[string]Downloader{
    76  			"http": new(HTTPDownloader),
    77  		}
    78  	}
    79  
    80  	return &DownloadClient{config: c}
    81  }
    82  
    83  // A downloader is responsible for actually taking a remote URL and
    84  // downloading it.
    85  type Downloader interface {
    86  	Cancel()
    87  	Download(io.Writer, *url.URL) error
    88  	Progress() uint
    89  	Total() uint
    90  }
    91  
    92  func (d *DownloadClient) Cancel() {
    93  	// TODO(mitchellh): Implement
    94  }
    95  
    96  func (d *DownloadClient) Get() (string, error) {
    97  	// If we already have the file and it matches, then just return the target path.
    98  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
    99  		log.Println("Initial checksum matched, no download needed.")
   100  		return d.config.TargetPath, nil
   101  	}
   102  
   103  	url, err := url.Parse(d.config.Url)
   104  	if err != nil {
   105  		return "", err
   106  	}
   107  
   108  	log.Printf("Parsed URL: %#v", url)
   109  
   110  	// Files when we don't copy the file are special cased.
   111  	var finalPath string
   112  	if url.Scheme == "file" && !d.config.CopyFile {
   113  		finalPath = url.Path
   114  
   115  		// Remove forward slash on absolute Windows file URLs before processing
   116  		if runtime.GOOS == "windows" && finalPath[0] == '/' {
   117  			finalPath = finalPath[1:len(finalPath)]
   118  		}
   119  	} else {
   120  		finalPath = d.config.TargetPath
   121  
   122  		var ok bool
   123  		d.downloader, ok = d.config.DownloaderMap[url.Scheme]
   124  		if !ok {
   125  			return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme)
   126  		}
   127  
   128  		// Otherwise, download using the downloader.
   129  		f, err := os.Create(finalPath)
   130  		if err != nil {
   131  			return "", err
   132  		}
   133  		defer f.Close()
   134  
   135  		log.Printf("Downloading: %s", url.String())
   136  		err = d.downloader.Download(f, url)
   137  		if err != nil {
   138  			return "", err
   139  		}
   140  	}
   141  
   142  	if d.config.Hash != nil {
   143  		var verify bool
   144  		verify, err = d.VerifyChecksum(finalPath)
   145  		if err == nil && !verify {
   146  			err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum))
   147  		}
   148  	}
   149  
   150  	return finalPath, err
   151  }
   152  
   153  // PercentProgress returns the download progress as a percentage.
   154  func (d *DownloadClient) PercentProgress() int {
   155  	if d.downloader == nil {
   156  		return -1
   157  	}
   158  
   159  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   160  }
   161  
   162  // VerifyChecksum tests that the path matches the checksum for the
   163  // download.
   164  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   165  	if d.config.Checksum == nil || d.config.Hash == nil {
   166  		return false, errors.New("Checksum or Hash isn't set on download.")
   167  	}
   168  
   169  	f, err := os.Open(path)
   170  	if err != nil {
   171  		return false, err
   172  	}
   173  	defer f.Close()
   174  
   175  	log.Printf("Verifying checksum of %s", path)
   176  	d.config.Hash.Reset()
   177  	io.Copy(d.config.Hash, f)
   178  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   179  }
   180  
   181  // HTTPDownloader is an implementation of Downloader that downloads
   182  // files over HTTP.
   183  type HTTPDownloader struct {
   184  	progress uint
   185  	total    uint
   186  }
   187  
   188  func (*HTTPDownloader) Cancel() {
   189  	// TODO(mitchellh): Implement
   190  }
   191  
   192  func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
   193  	log.Printf("Starting download: %s", src.String())
   194  	req, err := http.NewRequest("GET", src.String(), nil)
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	httpClient := &http.Client{
   200  		Transport: &http.Transport{
   201  			Proxy: http.ProxyFromEnvironment,
   202  		},
   203  	}
   204  
   205  	resp, err := httpClient.Do(req)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	if resp.StatusCode != 200 {
   211  		log.Printf(
   212  			"Non-200 status code: %d. Getting error body.", resp.StatusCode)
   213  
   214  		errorBody := new(bytes.Buffer)
   215  		io.Copy(errorBody, resp.Body)
   216  		return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
   217  			resp.StatusCode, errorBody.String())
   218  	}
   219  
   220  	d.progress = 0
   221  	d.total = uint(resp.ContentLength)
   222  
   223  	var buffer [4096]byte
   224  	for {
   225  		n, err := resp.Body.Read(buffer[:])
   226  		if err != nil && err != io.EOF {
   227  			return err
   228  		}
   229  
   230  		d.progress += uint(n)
   231  
   232  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   233  			return werr
   234  		}
   235  
   236  		if err == io.EOF {
   237  			break
   238  		}
   239  	}
   240  
   241  	return nil
   242  }
   243  
   244  func (d *HTTPDownloader) Progress() uint {
   245  	return d.progress
   246  }
   247  
   248  func (d *HTTPDownloader) Total() uint {
   249  	return d.total
   250  }