github.com/amanya/packer@v0.12.1-0.20161117214323-902ac5ab2eb6/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
type DownloadConfig struct {
	// Url is the source URL in the form of a string. It is parsed with
	// url.Parse when the download is performed.
	Url string

	// TargetPath is the path to download the file to.
	TargetPath string

	// DownloaderMap maps a URL scheme (for example "http") to the
	// Downloader that handles it. When it is nil, NewDownloadClient
	// installs HTTP downloaders for the "http" and "https" schemes.
	DownloaderMap map[string]Downloader

	// If true, this will copy even a local file to the target
	// location. If false, then it will "download" the file by just
	// returning the local path to the file.
	CopyFile bool

	// Hash is the hashing implementation used to checksum the
	// downloaded file.
	Hash hash.Hash

	// Checksum is the expected checksum of the downloaded file. After the
	// download, the digest computed with Hash is compared against it.
	Checksum []byte

	// UserAgent is what to use for the user agent for HTTP requests. If
	// set to "", the default user agent provided by Go is used.
	UserAgent string
}
    51  
// A DownloadClient helps download, verify checksums, etc.
//
// config is the configuration the client was created with. downloader is
// the Downloader that Get selects for the URL's scheme; it stays nil until
// Get has picked one, which PercentProgress reports as -1.
type DownloadClient struct {
	config     *DownloadConfig
	downloader Downloader
}
    57  
    58  // HashForType returns the Hash implementation for the given string
    59  // type, or nil if the type is not supported.
    60  func HashForType(t string) hash.Hash {
    61  	switch t {
    62  	case "md5":
    63  		return md5.New()
    64  	case "sha1":
    65  		return sha1.New()
    66  	case "sha256":
    67  		return sha256.New()
    68  	case "sha512":
    69  		return sha512.New()
    70  	default:
    71  		return nil
    72  	}
    73  }
    74  
    75  // NewDownloadClient returns a new DownloadClient for the given
    76  // configuration.
    77  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    78  	if c.DownloaderMap == nil {
    79  		c.DownloaderMap = map[string]Downloader{
    80  			"http":  &HTTPDownloader{userAgent: c.UserAgent},
    81  			"https": &HTTPDownloader{userAgent: c.UserAgent},
    82  		}
    83  	}
    84  
    85  	return &DownloadClient{config: c}
    86  }
    87  
// Downloader is responsible for actually taking a remote URL and
// downloading it.
type Downloader interface {
	// Cancel is meant to abort a download.
	// NOTE(review): the implementations visible in this file are empty
	// stubs, so cancellation is not actually available yet.
	Cancel()

	// Download fetches the given URL and writes its content to the
	// given file.
	Download(*os.File, *url.URL) error

	// Progress and Total report how far the download has come and how
	// large it is expected to be. DownloadClient.PercentProgress divides
	// Progress by Total to obtain a percentage.
	Progress() uint
	Total() uint
}
    96  
// Cancel currently does nothing; cancelling a download is not implemented.
func (d *DownloadClient) Cancel() {
	// TODO(mitchellh): Implement
}
   100  
   101  func (d *DownloadClient) Get() (string, error) {
   102  	// If we already have the file and it matches, then just return the target path.
   103  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   104  		log.Println("[DEBUG] Initial checksum matched, no download needed.")
   105  		return d.config.TargetPath, nil
   106  	}
   107  
   108  	u, err := url.Parse(d.config.Url)
   109  	if err != nil {
   110  		return "", err
   111  	}
   112  
   113  	log.Printf("Parsed URL: %#v", u)
   114  
   115  	// Files when we don't copy the file are special cased.
   116  	var f *os.File
   117  	var finalPath string
   118  	sourcePath := ""
   119  	if u.Scheme == "file" && !d.config.CopyFile {
   120  		// This is special case for relative path in this case user specify
   121  		// file:../ and after parse destination goes to Opaque
   122  		if u.Path != "" {
   123  			// If url.Path is set just use this
   124  			finalPath = u.Path
   125  		} else if u.Opaque != "" {
   126  			// otherwise try url.Opaque
   127  			finalPath = u.Opaque
   128  		}
   129  		// This is a special case where we use a source file that already exists
   130  		// locally and we don't make a copy. Normally we would copy or download.
   131  		log.Printf("[DEBUG] Using local file: %s", finalPath)
   132  
   133  		// Remove forward slash on absolute Windows file URLs before processing
   134  		if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' {
   135  			finalPath = finalPath[1:]
   136  		}
   137  		// Keep track of the source so we can make sure not to delete this later
   138  		sourcePath = finalPath
   139  		if _, err = os.Stat(finalPath); err != nil {
   140  			return "", err
   141  		}
   142  	} else {
   143  		finalPath = d.config.TargetPath
   144  
   145  		var ok bool
   146  		d.downloader, ok = d.config.DownloaderMap[u.Scheme]
   147  		if !ok {
   148  			return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
   149  		}
   150  
   151  		// Otherwise, download using the downloader.
   152  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   153  		if err != nil {
   154  			return "", err
   155  		}
   156  
   157  		log.Printf("[DEBUG] Downloading: %s", u.String())
   158  		err = d.downloader.Download(f, u)
   159  		f.Close()
   160  		if err != nil {
   161  			return "", err
   162  		}
   163  	}
   164  
   165  	if d.config.Hash != nil {
   166  		var verify bool
   167  		verify, err = d.VerifyChecksum(finalPath)
   168  		if err == nil && !verify {
   169  			// Only delete the file if we made a copy or downloaded it
   170  			if sourcePath != finalPath {
   171  				os.Remove(finalPath)
   172  			}
   173  
   174  			err = fmt.Errorf(
   175  				"checksums didn't match expected: %s",
   176  				hex.EncodeToString(d.config.Checksum))
   177  		}
   178  	}
   179  
   180  	return finalPath, err
   181  }
   182  
   183  // PercentProgress returns the download progress as a percentage.
   184  func (d *DownloadClient) PercentProgress() int {
   185  	if d.downloader == nil {
   186  		return -1
   187  	}
   188  
   189  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   190  }
   191  
   192  // VerifyChecksum tests that the path matches the checksum for the
   193  // download.
   194  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   195  	if d.config.Checksum == nil || d.config.Hash == nil {
   196  		return false, errors.New("Checksum or Hash isn't set on download.")
   197  	}
   198  
   199  	f, err := os.Open(path)
   200  	if err != nil {
   201  		return false, err
   202  	}
   203  	defer f.Close()
   204  
   205  	log.Printf("Verifying checksum of %s", path)
   206  	d.config.Hash.Reset()
   207  	io.Copy(d.config.Hash, f)
   208  	return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil
   209  }
   210  
// HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP.
//
// progress is a byte counter maintained by Download: it starts at the size
// of the existing file when a download is resumed (otherwise at zero) and
// grows with every chunk read from the response. total is set by Download
// to the starting progress plus the response's Content-Length. userAgent,
// when non-empty, is sent as the User-Agent header.
type HTTPDownloader struct {
	progress  uint
	total     uint
	userAgent string
}
   218  
// Cancel currently does nothing; cancelling an HTTP download is not
// implemented.
func (*HTTPDownloader) Cancel() {
	// TODO(mitchellh): Implement
}
   222  
   223  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   224  	log.Printf("Starting download: %s", src.String())
   225  
   226  	// Seek to the beginning by default
   227  	if _, err := dst.Seek(0, 0); err != nil {
   228  		return err
   229  	}
   230  
   231  	// Reset our progress
   232  	d.progress = 0
   233  
   234  	// Make the request. We first make a HEAD request so we can check
   235  	// if the server supports range queries. If the server/URL doesn't
   236  	// support HEAD requests, we just fall back to GET.
   237  	req, err := http.NewRequest("HEAD", src.String(), nil)
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	if d.userAgent != "" {
   243  		req.Header.Set("User-Agent", d.userAgent)
   244  	}
   245  
   246  	httpClient := &http.Client{
   247  		Transport: &http.Transport{
   248  			Proxy: http.ProxyFromEnvironment,
   249  		},
   250  	}
   251  
   252  	resp, err := httpClient.Do(req)
   253  	if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
   254  		// If the HEAD request succeeded, then attempt to set the range
   255  		// query if we can.
   256  		if resp.Header.Get("Accept-Ranges") == "bytes" {
   257  			if fi, err := dst.Stat(); err == nil {
   258  				if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   259  					req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   260  					d.progress = uint(fi.Size())
   261  				}
   262  			}
   263  		}
   264  	}
   265  
   266  	// Set the request to GET now, and redo the query to download
   267  	req.Method = "GET"
   268  
   269  	resp, err = httpClient.Do(req)
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	d.total = d.progress + uint(resp.ContentLength)
   275  	var buffer [4096]byte
   276  	for {
   277  		n, err := resp.Body.Read(buffer[:])
   278  		if err != nil && err != io.EOF {
   279  			return err
   280  		}
   281  
   282  		d.progress += uint(n)
   283  
   284  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   285  			return werr
   286  		}
   287  
   288  		if err == io.EOF {
   289  			break
   290  		}
   291  	}
   292  
   293  	return nil
   294  }
   295  
// Progress returns the downloader's progress counter, which Download
// maintains as a byte count (including the bytes already present in the
// file when a download is resumed).
func (d *HTTPDownloader) Progress() uint {
	return d.progress
}
   299  
// Total returns the expected size of the download as computed by Download:
// the offset the transfer started at plus the response's Content-Length.
func (d *HTTPDownloader) Total() uint {
	return d.total
}