github.com/rsyabuta/packer@v1.1.4-0.20180119234903-5ef0c2280f0b/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"runtime"
    19  )
    20  
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
//
// Note that NewDownloadClient itself fills in DownloaderMap when it is nil,
// so the value passed in may be mutated by the constructor.
type DownloadConfig struct {
	// The source URL in the form of a string. It is parsed with url.Parse
	// when the download starts; the "file" scheme is treated specially
	// (see CopyFile).
	Url string

	// This is the path to download the file to. The file is created if it
	// does not exist yet.
	TargetPath string

	// DownloaderMap maps a URL scheme (e.g. "http") to the Downloader that
	// handles it. When nil, NewDownloadClient installs HTTP downloaders for
	// the "http" and "https" schemes.
	DownloaderMap map[string]Downloader

	// If true, this will copy even a local file to the target
	// location. If false, then it will "download" the file by just
	// returning the local path to the file.
	//
	// NOTE(review): copying a "file" URL goes through DownloaderMap like any
	// other scheme, so it needs a "file" entry there; the default map only
	// registers "http" and "https".
	CopyFile bool

	// The hashing implementation to use to checksum the downloaded file.
	// The instance is reset and reused on every verification.
	Hash hash.Hash

	// The checksum for the downloaded file. The hash implementation configuration
	// for the downloader will be used to verify with this checksum after
	// it is downloaded. These are the raw digest bytes, not a hex string.
	Checksum []byte

	// What to use for the user agent for HTTP requests. If set to "", use the
	// default user agent provided by Go.
	UserAgent string
}
    51  
// A DownloadClient helps download, verify checksums, etc.
type DownloadClient struct {
	// config is the configuration the client was created with.
	config *DownloadConfig
	// downloader is the Downloader selected for the URL scheme by the most
	// recent Get call; it stays nil until Get picks one.
	downloader Downloader
}
    57  
    58  // HashForType returns the Hash implementation for the given string
    59  // type, or nil if the type is not supported.
    60  func HashForType(t string) hash.Hash {
    61  	switch t {
    62  	case "md5":
    63  		return md5.New()
    64  	case "sha1":
    65  		return sha1.New()
    66  	case "sha256":
    67  		return sha256.New()
    68  	case "sha512":
    69  		return sha512.New()
    70  	default:
    71  		return nil
    72  	}
    73  }
    74  
    75  // NewDownloadClient returns a new DownloadClient for the given
    76  // configuration.
    77  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    78  	if c.DownloaderMap == nil {
    79  		c.DownloaderMap = map[string]Downloader{
    80  			"http":  &HTTPDownloader{userAgent: c.UserAgent},
    81  			"https": &HTTPDownloader{userAgent: c.UserAgent},
    82  		}
    83  	}
    84  
    85  	return &DownloadClient{config: c}
    86  }
    87  
// A Downloader is responsible for actually taking a remote URL and
// downloading it.
type Downloader interface {
	// Cancel is meant to abort a download in progress.
	Cancel()
	// Download fetches the given URL into the given, already opened file.
	Download(*os.File, *url.URL) error
	// Progress reports how much of the download has completed so far, in
	// the same unit as Total (bytes for the HTTP implementation).
	Progress() uint
	// Total reports the expected overall size of the download.
	Total() uint
}
    96  
// Cancel is meant to abort an in-progress download. It is not implemented
// yet and currently does nothing.
func (d *DownloadClient) Cancel() {
	// TODO(mitchellh): Implement
}
   100  
   101  func (d *DownloadClient) Get() (string, error) {
   102  	// If we already have the file and it matches, then just return the target path.
   103  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   104  		log.Println("[DEBUG] Initial checksum matched, no download needed.")
   105  		return d.config.TargetPath, nil
   106  	}
   107  
   108  	u, err := url.Parse(d.config.Url)
   109  	if err != nil {
   110  		return "", err
   111  	}
   112  
   113  	log.Printf("Parsed URL: %#v", u)
   114  
   115  	// Files when we don't copy the file are special cased.
   116  	var f *os.File
   117  	var finalPath string
   118  	sourcePath := ""
   119  	if u.Scheme == "file" && !d.config.CopyFile {
   120  		// This is special case for relative path in this case user specify
   121  		// file:../ and after parse destination goes to Opaque
   122  		if u.Path != "" {
   123  			// If url.Path is set just use this
   124  			finalPath = u.Path
   125  		} else if u.Opaque != "" {
   126  			// otherwise try url.Opaque
   127  			finalPath = u.Opaque
   128  		}
   129  		// This is a special case where we use a source file that already exists
   130  		// locally and we don't make a copy. Normally we would copy or download.
   131  		log.Printf("[DEBUG] Using local file: %s", finalPath)
   132  
   133  		// Remove forward slash on absolute Windows file URLs before processing
   134  		if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' {
   135  			finalPath = finalPath[1:]
   136  		}
   137  
   138  		// Keep track of the source so we can make sure not to delete this later
   139  		sourcePath = finalPath
   140  		if _, err = os.Stat(finalPath); err != nil {
   141  			return "", err
   142  		}
   143  	} else {
   144  		finalPath = d.config.TargetPath
   145  
   146  		var ok bool
   147  		d.downloader, ok = d.config.DownloaderMap[u.Scheme]
   148  		if !ok {
   149  			return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
   150  		}
   151  
   152  		// Otherwise, download using the downloader.
   153  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   154  		if err != nil {
   155  			return "", err
   156  		}
   157  
   158  		log.Printf("[DEBUG] Downloading: %s", u.String())
   159  		err = d.downloader.Download(f, u)
   160  		f.Close()
   161  		if err != nil {
   162  			return "", err
   163  		}
   164  	}
   165  
   166  	if d.config.Hash != nil {
   167  		var verify bool
   168  		verify, err = d.VerifyChecksum(finalPath)
   169  		if err == nil && !verify {
   170  			// Only delete the file if we made a copy or downloaded it
   171  			if sourcePath != finalPath {
   172  				os.Remove(finalPath)
   173  			}
   174  
   175  			err = fmt.Errorf(
   176  				"checksums didn't match expected: %s",
   177  				hex.EncodeToString(d.config.Checksum))
   178  		}
   179  	}
   180  
   181  	return finalPath, err
   182  }
   183  
   184  // PercentProgress returns the download progress as a percentage.
   185  func (d *DownloadClient) PercentProgress() int {
   186  	if d.downloader == nil {
   187  		return -1
   188  	}
   189  
   190  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   191  }
   192  
   193  // VerifyChecksum tests that the path matches the checksum for the
   194  // download.
   195  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   196  	if d.config.Checksum == nil || d.config.Hash == nil {
   197  		return false, errors.New("Checksum or Hash isn't set on download.")
   198  	}
   199  
   200  	f, err := os.Open(path)
   201  	if err != nil {
   202  		return false, err
   203  	}
   204  	defer f.Close()
   205  
   206  	log.Printf("Verifying checksum of %s", path)
   207  	d.config.Hash.Reset()
   208  	io.Copy(d.config.Hash, f)
   209  	return bytes.Equal(d.config.Hash.Sum(nil), d.config.Checksum), nil
   210  }
   211  
   212  // HTTPDownloader is an implementation of Downloader that downloads
   213  // files over HTTP.
   214  type HTTPDownloader struct {
   215  	progress  uint
   216  	total     uint
   217  	userAgent string
   218  }
   219  
   220  func (*HTTPDownloader) Cancel() {
   221  	// TODO(mitchellh): Implement
   222  }
   223  
   224  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   225  	log.Printf("Starting download: %s", src.String())
   226  
   227  	// Seek to the beginning by default
   228  	if _, err := dst.Seek(0, 0); err != nil {
   229  		return err
   230  	}
   231  
   232  	// Reset our progress
   233  	d.progress = 0
   234  
   235  	// Make the request. We first make a HEAD request so we can check
   236  	// if the server supports range queries. If the server/URL doesn't
   237  	// support HEAD requests, we just fall back to GET.
   238  	req, err := http.NewRequest("HEAD", src.String(), nil)
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	if d.userAgent != "" {
   244  		req.Header.Set("User-Agent", d.userAgent)
   245  	}
   246  
   247  	httpClient := &http.Client{
   248  		Transport: &http.Transport{
   249  			Proxy: http.ProxyFromEnvironment,
   250  		},
   251  	}
   252  
   253  	resp, err := httpClient.Do(req)
   254  	if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
   255  		// If the HEAD request succeeded, then attempt to set the range
   256  		// query if we can.
   257  		if resp.Header.Get("Accept-Ranges") == "bytes" {
   258  			if fi, err := dst.Stat(); err == nil {
   259  				if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   260  					req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   261  					d.progress = uint(fi.Size())
   262  				}
   263  			}
   264  		}
   265  	}
   266  
   267  	// Set the request to GET now, and redo the query to download
   268  	req.Method = "GET"
   269  
   270  	resp, err = httpClient.Do(req)
   271  	if err != nil {
   272  		return err
   273  	}
   274  
   275  	d.total = d.progress + uint(resp.ContentLength)
   276  	var buffer [4096]byte
   277  	for {
   278  		n, err := resp.Body.Read(buffer[:])
   279  		if err != nil && err != io.EOF {
   280  			return err
   281  		}
   282  
   283  		d.progress += uint(n)
   284  
   285  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   286  			return werr
   287  		}
   288  
   289  		if err == io.EOF {
   290  			break
   291  		}
   292  	}
   293  
   294  	return nil
   295  }
   296  
   297  func (d *HTTPDownloader) Progress() uint {
   298  	return d.progress
   299  }
   300  
   301  func (d *HTTPDownloader) Total() uint {
   302  	return d.total
   303  }