github.com/sudharsh/packer@v0.4.1/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
    21  // DownloadConfig is the configuration given to instantiate a new
    22  // download instance. Once a configuration is used to instantiate
    23  // a download client, it must not be modified.
    24  type DownloadConfig struct {
    25  	// The source URL in the form of a string.
    26  	Url string
    27  
    28  	// This is the path to download the file to.
    29  	TargetPath string
    30  
    31  	// DownloaderMap maps a schema to a Download.
    32  	DownloaderMap map[string]Downloader
    33  
    34  	// If true, this will copy even a local file to the target
    35  	// location. If false, then it will "download" the file by just
    36  	// returning the local path to the file.
    37  	CopyFile bool
    38  
    39  	// The hashing implementation to use to checksum the downloaded file.
    40  	Hash hash.Hash
    41  
    42  	// The checksum for the downloaded file. The hash implementation configuration
    43  	// for the downloader will be used to verify with this checksum after
    44  	// it is downloaded.
    45  	Checksum []byte
    46  }
    47  
    48  // A DownloadClient helps download, verify checksums, etc.
    49  type DownloadClient struct {
    50  	config     *DownloadConfig
    51  	downloader Downloader
    52  }
    53  
    54  // HashForType returns the Hash implementation for the given string
    55  // type, or nil if the type is not supported.
    56  func HashForType(t string) hash.Hash {
    57  	switch t {
    58  	case "md5":
    59  		return md5.New()
    60  	case "sha1":
    61  		return sha1.New()
    62  	case "sha256":
    63  		return sha256.New()
    64  	case "sha512":
    65  		return sha512.New()
    66  	default:
    67  		return nil
    68  	}
    69  }
    70  
    71  // NewDownloadClient returns a new DownloadClient for the given
    72  // configuration.
    73  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    74  	if c.DownloaderMap == nil {
    75  		c.DownloaderMap = map[string]Downloader{
    76  			"http":  new(HTTPDownloader),
    77  			"https": new(HTTPDownloader),
    78  		}
    79  	}
    80  
    81  	return &DownloadClient{config: c}
    82  }
    83  
    84  // A downloader is responsible for actually taking a remote URL and
    85  // downloading it.
    86  type Downloader interface {
    87  	Cancel()
    88  	Download(io.Writer, *url.URL) error
    89  	Progress() uint
    90  	Total() uint
    91  }
    92  
    93  func (d *DownloadClient) Cancel() {
    94  	// TODO(mitchellh): Implement
    95  }
    96  
    97  func (d *DownloadClient) Get() (string, error) {
    98  	// If we already have the file and it matches, then just return the target path.
    99  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   100  		log.Println("Initial checksum matched, no download needed.")
   101  		return d.config.TargetPath, nil
   102  	}
   103  
   104  	url, err := url.Parse(d.config.Url)
   105  	if err != nil {
   106  		return "", err
   107  	}
   108  
   109  	log.Printf("Parsed URL: %#v", url)
   110  
   111  	// Files when we don't copy the file are special cased.
   112  	var finalPath string
   113  	if url.Scheme == "file" && !d.config.CopyFile {
   114  		finalPath = url.Path
   115  
   116  		// Remove forward slash on absolute Windows file URLs before processing
   117  		if runtime.GOOS == "windows" && finalPath[0] == '/' {
   118  			finalPath = finalPath[1:len(finalPath)]
   119  		}
   120  	} else {
   121  		finalPath = d.config.TargetPath
   122  
   123  		var ok bool
   124  		d.downloader, ok = d.config.DownloaderMap[url.Scheme]
   125  		if !ok {
   126  			return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme)
   127  		}
   128  
   129  		// Otherwise, download using the downloader.
   130  		f, err := os.Create(finalPath)
   131  		if err != nil {
   132  			return "", err
   133  		}
   134  		defer f.Close()
   135  
   136  		log.Printf("Downloading: %s", url.String())
   137  		err = d.downloader.Download(f, url)
   138  		if err != nil {
   139  			return "", err
   140  		}
   141  	}
   142  
   143  	if d.config.Hash != nil {
   144  		var verify bool
   145  		verify, err = d.VerifyChecksum(finalPath)
   146  		if err == nil && !verify {
   147  			err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum))
   148  		}
   149  	}
   150  
   151  	return finalPath, err
   152  }
   153  
   154  // PercentProgress returns the download progress as a percentage.
   155  func (d *DownloadClient) PercentProgress() int {
   156  	if d.downloader == nil {
   157  		return -1
   158  	}
   159  
   160  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   161  }
   162  
   163  // VerifyChecksum tests that the path matches the checksum for the
   164  // download.
   165  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   166  	if d.config.Checksum == nil || d.config.Hash == nil {
   167  		return false, errors.New("Checksum or Hash isn't set on download.")
   168  	}
   169  
   170  	f, err := os.Open(path)
   171  	if err != nil {
   172  		return false, err
   173  	}
   174  	defer f.Close()
   175  
   176  	log.Printf("Verifying checksum of %s", path)
   177  	d.config.Hash.Reset()
   178  	io.Copy(d.config.Hash, f)
   179  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   180  }
   181  
   182  // HTTPDownloader is an implementation of Downloader that downloads
   183  // files over HTTP.
   184  type HTTPDownloader struct {
   185  	progress uint
   186  	total    uint
   187  }
   188  
   189  func (*HTTPDownloader) Cancel() {
   190  	// TODO(mitchellh): Implement
   191  }
   192  
   193  func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
   194  	log.Printf("Starting download: %s", src.String())
   195  	req, err := http.NewRequest("GET", src.String(), nil)
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	httpClient := &http.Client{
   201  		Transport: &http.Transport{
   202  			Proxy: http.ProxyFromEnvironment,
   203  		},
   204  	}
   205  
   206  	resp, err := httpClient.Do(req)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	if resp.StatusCode != 200 {
   212  		log.Printf(
   213  			"Non-200 status code: %d. Getting error body.", resp.StatusCode)
   214  
   215  		errorBody := new(bytes.Buffer)
   216  		io.Copy(errorBody, resp.Body)
   217  		return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
   218  			resp.StatusCode, errorBody.String())
   219  	}
   220  
   221  	d.progress = 0
   222  	d.total = uint(resp.ContentLength)
   223  
   224  	var buffer [4096]byte
   225  	for {
   226  		n, err := resp.Body.Read(buffer[:])
   227  		if err != nil && err != io.EOF {
   228  			return err
   229  		}
   230  
   231  		d.progress += uint(n)
   232  
   233  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   234  			return werr
   235  		}
   236  
   237  		if err == io.EOF {
   238  			break
   239  		}
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func (d *HTTPDownloader) Progress() uint {
   246  	return d.progress
   247  }
   248  
   249  func (d *HTTPDownloader) Total() uint {
   250  	return d.total
   251  }