github.com/phobos182/packer@v0.2.3-0.20130819023704-c84d2aeffc68/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"hash"
    12  	"io"
    13  	"log"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"runtime"
    18  )
    19  
    20  // DownloadConfig is the configuration given to instantiate a new
    21  // download instance. Once a configuration is used to instantiate
    22  // a download client, it must not be modified.
    23  type DownloadConfig struct {
    24  	// The source URL in the form of a string.
    25  	Url string
    26  
    27  	// This is the path to download the file to.
    28  	TargetPath string
    29  
    30  	// DownloaderMap maps a schema to a Download.
    31  	DownloaderMap map[string]Downloader
    32  
    33  	// If true, this will copy even a local file to the target
    34  	// location. If false, then it will "download" the file by just
    35  	// returning the local path to the file.
    36  	CopyFile bool
    37  
    38  	// The hashing implementation to use to checksum the downloaded file.
    39  	Hash hash.Hash
    40  
    41  	// The checksum for the downloaded file. The hash implementation configuration
    42  	// for the downloader will be used to verify with this checksum after
    43  	// it is downloaded.
    44  	Checksum []byte
    45  }
    46  
    47  // A DownloadClient helps download, verify checksums, etc.
    48  type DownloadClient struct {
    49  	config     *DownloadConfig
    50  	downloader Downloader
    51  }
    52  
    53  // HashForType returns the Hash implementation for the given string
    54  // type, or nil if the type is not supported.
    55  func HashForType(t string) hash.Hash {
    56  	switch t {
    57  	case "md5":
    58  		return md5.New()
    59  	case "sha1":
    60  		return sha1.New()
    61  	case "sha256":
    62  		return sha256.New()
    63  	default:
    64  		return nil
    65  	}
    66  }
    67  
    68  // NewDownloadClient returns a new DownloadClient for the given
    69  // configuration.
    70  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    71  	if c.DownloaderMap == nil {
    72  		c.DownloaderMap = map[string]Downloader{
    73  			"http": new(HTTPDownloader),
    74  		}
    75  	}
    76  
    77  	return &DownloadClient{config: c}
    78  }
    79  
    80  // A downloader is responsible for actually taking a remote URL and
    81  // downloading it.
    82  type Downloader interface {
    83  	Cancel()
    84  	Download(io.Writer, *url.URL) error
    85  	Progress() uint
    86  	Total() uint
    87  }
    88  
    89  func (d *DownloadClient) Cancel() {
    90  	// TODO(mitchellh): Implement
    91  }
    92  
    93  func (d *DownloadClient) Get() (string, error) {
    94  	// If we already have the file and it matches, then just return the target path.
    95  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
    96  		log.Println("Initial checksum matched, no download needed.")
    97  		return d.config.TargetPath, nil
    98  	}
    99  
   100  	url, err := url.Parse(d.config.Url)
   101  	if err != nil {
   102  		return "", err
   103  	}
   104  
   105  	log.Printf("Parsed URL: %#v", url)
   106  
   107  	// Files when we don't copy the file are special cased.
   108  	var finalPath string
   109  	if url.Scheme == "file" && !d.config.CopyFile {
   110  		finalPath = url.Path
   111  
   112  		// Remove forward slash on absolute Windows file URLs before processing
   113  		if runtime.GOOS == "windows" && finalPath[0] == '/' {
   114  			finalPath = finalPath[1:len(finalPath)]
   115  		}
   116  	} else {
   117  		finalPath = d.config.TargetPath
   118  
   119  		var ok bool
   120  		d.downloader, ok = d.config.DownloaderMap[url.Scheme]
   121  		if !ok {
   122  			return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme)
   123  		}
   124  
   125  		// Otherwise, download using the downloader.
   126  		f, err := os.Create(finalPath)
   127  		if err != nil {
   128  			return "", err
   129  		}
   130  		defer f.Close()
   131  
   132  		log.Printf("Downloading: %s", url.String())
   133  		err = d.downloader.Download(f, url)
   134  		if err != nil {
   135  			return "", err
   136  		}
   137  	}
   138  
   139  	if d.config.Hash != nil {
   140  		var verify bool
   141  		verify, err = d.VerifyChecksum(finalPath)
   142  		if err == nil && !verify {
   143  			err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum))
   144  		}
   145  	}
   146  
   147  	return finalPath, err
   148  }
   149  
   150  // PercentProgress returns the download progress as a percentage.
   151  func (d *DownloadClient) PercentProgress() int {
   152  	if d.downloader == nil {
   153  		return -1
   154  	}
   155  
   156  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   157  }
   158  
   159  // VerifyChecksum tests that the path matches the checksum for the
   160  // download.
   161  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   162  	if d.config.Checksum == nil || d.config.Hash == nil {
   163  		return false, errors.New("Checksum or Hash isn't set on download.")
   164  	}
   165  
   166  	f, err := os.Open(path)
   167  	if err != nil {
   168  		return false, err
   169  	}
   170  	defer f.Close()
   171  
   172  	log.Printf("Verifying checksum of %s", path)
   173  	d.config.Hash.Reset()
   174  	io.Copy(d.config.Hash, f)
   175  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   176  }
   177  
   178  // HTTPDownloader is an implementation of Downloader that downloads
   179  // files over HTTP.
   180  type HTTPDownloader struct {
   181  	progress uint
   182  	total    uint
   183  }
   184  
   185  func (*HTTPDownloader) Cancel() {
   186  	// TODO(mitchellh): Implement
   187  }
   188  
   189  func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
   190  	log.Printf("Starting download: %s", src.String())
   191  	req, err := http.NewRequest("GET", src.String(), nil)
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	httpClient := &http.Client{
   197  		Transport: &http.Transport{
   198  			Proxy: http.ProxyFromEnvironment,
   199  		},
   200  	}
   201  
   202  	resp, err := httpClient.Do(req)
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	if resp.StatusCode != 200 {
   208  		log.Printf(
   209  			"Non-200 status code: %d. Getting error body.", resp.StatusCode)
   210  
   211  		errorBody := new(bytes.Buffer)
   212  		io.Copy(errorBody, resp.Body)
   213  		return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
   214  			resp.StatusCode, errorBody.String())
   215  	}
   216  
   217  	d.progress = 0
   218  	d.total = uint(resp.ContentLength)
   219  
   220  	var buffer [4096]byte
   221  	for {
   222  		n, err := resp.Body.Read(buffer[:])
   223  		if err != nil && err != io.EOF {
   224  			return err
   225  		}
   226  
   227  		d.progress += uint(n)
   228  
   229  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   230  			return werr
   231  		}
   232  
   233  		if err == io.EOF {
   234  			break
   235  		}
   236  	}
   237  
   238  	return nil
   239  }
   240  
   241  func (d *HTTPDownloader) Progress() uint {
   242  	return d.progress
   243  }
   244  
   245  func (d *HTTPDownloader) Total() uint {
   246  	return d.total
   247  }