github.com/rothwerx/packer@v0.9.0/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
//
// Note that NewDownloadClient itself fills in DownloaderMap when it is nil,
// so the configuration is mutated once during construction.
type DownloadConfig struct {
	// The source URL in the form of a string. Get parses it with url.Parse;
	// its scheme selects the Downloader from DownloaderMap, except for
	// "file" URLs when CopyFile is false, which are used in place.
	Url string

	// This is the path to download the file to. Get first checks whether a
	// file already at this path matches Checksum and, if so, skips the
	// download entirely.
	TargetPath string

	// DownloaderMap maps a URL scheme (for example "http") to the Downloader
	// that handles it. When nil, NewDownloadClient populates it with
	// HTTPDownloaders for "http" and "https".
	DownloaderMap map[string]Downloader

	// If true, this will copy even a local file to the target
	// location. If false, then it will "download" the file by just
	// returning the local path to the file.
	//
	// NOTE(review): copying goes through DownloaderMap["file"]; the default
	// map built by NewDownloadClient has no such entry, so with the default
	// map Get fails with "No downloader for scheme: file".
	CopyFile bool

	// The hashing implementation to use to checksum the downloaded file.
	// VerifyChecksum Resets and reuses this single instance on every call.
	// Get only verifies the result when this is non-nil.
	Hash hash.Hash

	// The checksum for the downloaded file. The hash implementation configuration
	// for the downloader will be used to verify with this checksum after
	// it is downloaded. It is compared byte-for-byte with Hash.Sum(nil).
	Checksum []byte

	// What to use for the user agent for HTTP requests. If set to "", use the
	// default user agent provided by Go.
	UserAgent string
}
    51  
// A DownloadClient helps download, verify checksums, etc.
// Create one with NewDownloadClient.
type DownloadClient struct {
	config     *DownloadConfig // configuration supplied to NewDownloadClient
	downloader Downloader      // chosen by Get from DownloaderMap; nil until then
}
    57  
    58  // HashForType returns the Hash implementation for the given string
    59  // type, or nil if the type is not supported.
    60  func HashForType(t string) hash.Hash {
    61  	switch t {
    62  	case "md5":
    63  		return md5.New()
    64  	case "sha1":
    65  		return sha1.New()
    66  	case "sha256":
    67  		return sha256.New()
    68  	case "sha512":
    69  		return sha512.New()
    70  	default:
    71  		return nil
    72  	}
    73  }
    74  
    75  // NewDownloadClient returns a new DownloadClient for the given
    76  // configuration.
    77  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    78  	if c.DownloaderMap == nil {
    79  		c.DownloaderMap = map[string]Downloader{
    80  			"http":  &HTTPDownloader{userAgent: c.UserAgent},
    81  			"https": &HTTPDownloader{userAgent: c.UserAgent},
    82  		}
    83  	}
    84  
    85  	return &DownloadClient{config: c}
    86  }
    87  
    88  // A downloader is responsible for actually taking a remote URL and
    89  // downloading it.
    90  type Downloader interface {
    91  	Cancel()
    92  	Download(*os.File, *url.URL) error
    93  	Progress() uint
    94  	Total() uint
    95  }
    96  
    97  func (d *DownloadClient) Cancel() {
    98  	// TODO(mitchellh): Implement
    99  }
   100  
   101  func (d *DownloadClient) Get() (string, error) {
   102  	// If we already have the file and it matches, then just return the target path.
   103  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   104  		log.Println("[DEBUG] Initial checksum matched, no download needed.")
   105  		return d.config.TargetPath, nil
   106  	}
   107  
   108  	url, err := url.Parse(d.config.Url)
   109  	if err != nil {
   110  		return "", err
   111  	}
   112  
   113  	log.Printf("Parsed URL: %#v", url)
   114  
   115  	// Files when we don't copy the file are special cased.
   116  	var f *os.File
   117  	var finalPath string
   118  	sourcePath := ""
   119  	if url.Scheme == "file" && !d.config.CopyFile {
   120  		// This is a special case where we use a source file that already exists
   121  		// locally and we don't make a copy. Normally we would copy or download.
   122  		finalPath = url.Path
   123  		log.Printf("[DEBUG] Using local file: %s", finalPath)
   124  
   125  		// Remove forward slash on absolute Windows file URLs before processing
   126  		if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' {
   127  			finalPath = finalPath[1:len(finalPath)]
   128  		}
   129  		// Keep track of the source so we can make sure not to delete this later
   130  		sourcePath = finalPath
   131  	} else {
   132  		finalPath = d.config.TargetPath
   133  
   134  		var ok bool
   135  		d.downloader, ok = d.config.DownloaderMap[url.Scheme]
   136  		if !ok {
   137  			return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme)
   138  		}
   139  
   140  		// Otherwise, download using the downloader.
   141  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   142  		if err != nil {
   143  			return "", err
   144  		}
   145  
   146  		log.Printf("[DEBUG] Downloading: %s", url.String())
   147  		err = d.downloader.Download(f, url)
   148  		f.Close()
   149  		if err != nil {
   150  			return "", err
   151  		}
   152  	}
   153  
   154  	if d.config.Hash != nil {
   155  		var verify bool
   156  		verify, err = d.VerifyChecksum(finalPath)
   157  		if err == nil && !verify {
   158  			// Only delete the file if we made a copy or downloaded it
   159  			if sourcePath != finalPath {
   160  				os.Remove(finalPath)
   161  			}
   162  
   163  			err = fmt.Errorf(
   164  				"checksums didn't match expected: %s",
   165  				hex.EncodeToString(d.config.Checksum))
   166  		}
   167  	}
   168  
   169  	return finalPath, err
   170  }
   171  
   172  // PercentProgress returns the download progress as a percentage.
   173  func (d *DownloadClient) PercentProgress() int {
   174  	if d.downloader == nil {
   175  		return -1
   176  	}
   177  
   178  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   179  }
   180  
   181  // VerifyChecksum tests that the path matches the checksum for the
   182  // download.
   183  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   184  	if d.config.Checksum == nil || d.config.Hash == nil {
   185  		return false, errors.New("Checksum or Hash isn't set on download.")
   186  	}
   187  
   188  	f, err := os.Open(path)
   189  	if err != nil {
   190  		return false, err
   191  	}
   192  	defer f.Close()
   193  
   194  	log.Printf("Verifying checksum of %s", path)
   195  	d.config.Hash.Reset()
   196  	io.Copy(d.config.Hash, f)
   197  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   198  }
   199  
   200  // HTTPDownloader is an implementation of Downloader that downloads
   201  // files over HTTP.
   202  type HTTPDownloader struct {
   203  	progress  uint
   204  	total     uint
   205  	userAgent string
   206  }
   207  
   208  func (*HTTPDownloader) Cancel() {
   209  	// TODO(mitchellh): Implement
   210  }
   211  
   212  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   213  	log.Printf("Starting download: %s", src.String())
   214  
   215  	// Seek to the beginning by default
   216  	if _, err := dst.Seek(0, 0); err != nil {
   217  		return err
   218  	}
   219  
   220  	// Reset our progress
   221  	d.progress = 0
   222  
   223  	// Make the request. We first make a HEAD request so we can check
   224  	// if the server supports range queries. If the server/URL doesn't
   225  	// support HEAD requests, we just fall back to GET.
   226  	req, err := http.NewRequest("HEAD", src.String(), nil)
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	if d.userAgent != "" {
   232  		req.Header.Set("User-Agent", d.userAgent)
   233  	}
   234  
   235  	httpClient := &http.Client{
   236  		Transport: &http.Transport{
   237  			Proxy: http.ProxyFromEnvironment,
   238  		},
   239  	}
   240  
   241  	resp, err := httpClient.Do(req)
   242  	if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
   243  		// If the HEAD request succeeded, then attempt to set the range
   244  		// query if we can.
   245  		if resp.Header.Get("Accept-Ranges") == "bytes" {
   246  			if fi, err := dst.Stat(); err == nil {
   247  				if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   248  					req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   249  					d.progress = uint(fi.Size())
   250  				}
   251  			}
   252  		}
   253  	}
   254  
   255  	// Set the request to GET now, and redo the query to download
   256  	req.Method = "GET"
   257  
   258  	resp, err = httpClient.Do(req)
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	d.total = d.progress + uint(resp.ContentLength)
   264  	var buffer [4096]byte
   265  	for {
   266  		n, err := resp.Body.Read(buffer[:])
   267  		if err != nil && err != io.EOF {
   268  			return err
   269  		}
   270  
   271  		d.progress += uint(n)
   272  
   273  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   274  			return werr
   275  		}
   276  
   277  		if err == io.EOF {
   278  			break
   279  		}
   280  	}
   281  
   282  	return nil
   283  }
   284  
   285  func (d *HTTPDownloader) Progress() uint {
   286  	return d.progress
   287  }
   288  
   289  func (d *HTTPDownloader) Total() uint {
   290  	return d.total
   291  }