github.com/StackPointCloud/packer@v0.10.2-0.20180716202532-b28098e0f79b/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"log"
    14  	"net/url"
    15  	"os"
    16  	"path"
    17  	"runtime"
    18  	"strings"
    19  )
    20  
    21  // imports related to each Downloader implementation
    22  import (
    23  	"io"
    24  	"net/http"
    25  	"path/filepath"
    26  )
    27  
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
type DownloadConfig struct {
	// The source URL in the form of a string. Get parses it with url.Parse,
	// and its scheme selects the entry of DownloaderMap that is used.
	Url string

	// This is the path to download the file to when the file is copied.
	// Get also checks it first: if it already matches Checksum, it is
	// returned without downloading anything.
	TargetPath string

	// DownloaderMap maps a URL scheme (e.g. "file", "http") to the Downloader
	// that handles it. When nil, NewDownloadClient fills in defaults for the
	// "file", "http", "https" and "smb" schemes.
	DownloaderMap map[string]Downloader

	// If true, this will copy even a local file to the target
	// location. If false, then it will "download" the file by just
	// returning the local path to the file. Get copies regardless when the
	// downloader for the URL's scheme cannot resolve a local path.
	CopyFile bool

	// The hashing implementation to use to checksum the downloaded file.
	// VerifyChecksum calls Reset on it before every verification.
	Hash hash.Hash

	// The checksum for the downloaded file. After the download, the file is
	// hashed with Hash and the result is compared against this value.
	Checksum []byte

	// What to use for the user agent for HTTP requests. If set to "", use the
	// default user agent provided by Go.
	UserAgent string
}
    58  
// A DownloadClient helps download, verify checksums, etc.
//
// Create one with NewDownloadClient.
type DownloadClient struct {
	config     *DownloadConfig // configuration supplied to NewDownloadClient
	downloader Downloader      // downloader selected by the most recent Get; nil before that
}
    64  
    65  // HashForType returns the Hash implementation for the given string
    66  // type, or nil if the type is not supported.
    67  func HashForType(t string) hash.Hash {
    68  	switch t {
    69  	case "md5":
    70  		return md5.New()
    71  	case "sha1":
    72  		return sha1.New()
    73  	case "sha256":
    74  		return sha256.New()
    75  	case "sha512":
    76  		return sha512.New()
    77  	default:
    78  		return nil
    79  	}
    80  }
    81  
    82  // NewDownloadClient returns a new DownloadClient for the given
    83  // configuration.
    84  func NewDownloadClient(c *DownloadConfig) *DownloadClient {
    85  	const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */
    86  
    87  	// Create downloader map if it hasn't been specified already.
    88  	if c.DownloaderMap == nil {
    89  		c.DownloaderMap = map[string]Downloader{
    90  			"file":  &FileDownloader{bufferSize: nil},
    91  			"http":  &HTTPDownloader{userAgent: c.UserAgent},
    92  			"https": &HTTPDownloader{userAgent: c.UserAgent},
    93  			"smb":   &SMBDownloader{bufferSize: nil},
    94  		}
    95  	}
    96  	return &DownloadClient{config: c}
    97  }
    98  
// A Downloader implements the ability to transfer, cancel, or resume a file,
// and reports the progress of that transfer.
type Downloader interface {
	// Resume is intended to continue an interrupted transfer; the
	// implementations in this file leave it unimplemented.
	Resume()
	// Cancel asks an in-progress transfer to stop.
	Cancel()
	// Progress returns the number of bytes transferred so far.
	Progress() uint64
	// Total returns the expected size of the transfer in bytes.
	Total() uint64
}
   106  
// A LocalDownloader is responsible for converting a uri to a local path
// that the platform can open directly.
//
// Its method is unexported, so only types in this package can implement it.
type LocalDownloader interface {
	// toPath resolves the URL to a filesystem path. The string argument is
	// the base directory for relative URLs; Get passes the current working
	// directory.
	toPath(string, url.URL) (string, error)
}
   112  
// A RemoteDownloader is responsible for actually taking a remote URL and
// downloading it.
//
// Download writes the content addressed by the URL into the given file.
// Despite the name, Get requires every registered downloader to implement
// this interface, including the ones for local "file" and "smb" URLs.
type RemoteDownloader interface {
	Download(*os.File, *url.URL) error
}
   118  
   119  func (d *DownloadClient) Cancel() {
   120  	// TODO(mitchellh): Implement
   121  }
   122  
   123  func (d *DownloadClient) Get() (string, error) {
   124  	// If we already have the file and it matches, then just return the target path.
   125  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   126  		log.Println("[DEBUG] Initial checksum matched, no download needed.")
   127  		return d.config.TargetPath, nil
   128  	}
   129  
   130  	/* parse the configuration url into a net/url object */
   131  	u, err := url.Parse(d.config.Url)
   132  	if err != nil {
   133  		return "", err
   134  	}
   135  	log.Printf("Parsed URL: %#v", u)
   136  
   137  	/* use the current working directory as the base for relative uri's */
   138  	cwd, err := os.Getwd()
   139  	if err != nil {
   140  		return "", err
   141  	}
   142  
   143  	// Determine which is the correct downloader to use
   144  	var finalPath string
   145  
   146  	var ok bool
   147  	d.downloader, ok = d.config.DownloaderMap[u.Scheme]
   148  	if !ok {
   149  		return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
   150  	}
   151  
   152  	remote, ok := d.downloader.(RemoteDownloader)
   153  	if !ok {
   154  		return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, d.downloader)
   155  	}
   156  
   157  	local, ok := d.downloader.(LocalDownloader)
   158  	if !ok && !d.config.CopyFile {
   159  		d.config.CopyFile = true
   160  	}
   161  
   162  	// If we're copying the file, then just use the actual downloader
   163  	if d.config.CopyFile {
   164  		var f *os.File
   165  		finalPath = d.config.TargetPath
   166  
   167  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   168  		if err != nil {
   169  			return "", err
   170  		}
   171  
   172  		log.Printf("[DEBUG] Downloading: %s", u.String())
   173  		err = remote.Download(f, u)
   174  		f.Close()
   175  		if err != nil {
   176  			return "", err
   177  		}
   178  
   179  		// Otherwise if our Downloader is a LocalDownloader we can just use the
   180  		//	path after transforming it.
   181  	} else {
   182  		finalPath, err = local.toPath(cwd, *u)
   183  		if err != nil {
   184  			return "", err
   185  		}
   186  
   187  		log.Printf("[DEBUG] Using local file: %s", finalPath)
   188  	}
   189  
   190  	if d.config.Hash != nil {
   191  		var verify bool
   192  		verify, err = d.VerifyChecksum(finalPath)
   193  		if err == nil && !verify {
   194  			// Only delete the file if we made a copy or downloaded it
   195  			if d.config.CopyFile {
   196  				os.Remove(finalPath)
   197  			}
   198  
   199  			err = fmt.Errorf(
   200  				"checksums didn't match expected: %s",
   201  				hex.EncodeToString(d.config.Checksum))
   202  		}
   203  	}
   204  
   205  	return finalPath, err
   206  }
   207  
   208  func (d *DownloadClient) PercentProgress() int {
   209  	if d.downloader == nil {
   210  		return -1
   211  	}
   212  
   213  	return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
   214  }
   215  
   216  // VerifyChecksum tests that the path matches the checksum for the
   217  // download.
   218  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   219  	if d.config.Checksum == nil || d.config.Hash == nil {
   220  		return false, errors.New("Checksum or Hash isn't set on download.")
   221  	}
   222  
   223  	f, err := os.Open(path)
   224  	if err != nil {
   225  		return false, err
   226  	}
   227  	defer f.Close()
   228  
   229  	log.Printf("Verifying checksum of %s", path)
   230  	d.config.Hash.Reset()
   231  	io.Copy(d.config.Hash, f)
   232  	return bytes.Equal(d.config.Hash.Sum(nil), d.config.Checksum), nil
   233  }
   234  
   235  // HTTPDownloader is an implementation of Downloader that downloads
   236  // files over HTTP.
   237  type HTTPDownloader struct {
   238  	current   uint64
   239  	total     uint64
   240  	userAgent string
   241  }
   242  
   243  func (d *HTTPDownloader) Cancel() {
   244  	// TODO(mitchellh): Implement
   245  }
   246  
   247  func (d *HTTPDownloader) Resume() {
   248  	// TODO(mitchellh): Implement
   249  }
   250  
   251  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   252  	log.Printf("Starting download over HTTP: %s", src.String())
   253  
   254  	// Seek to the beginning by default
   255  	if _, err := dst.Seek(0, 0); err != nil {
   256  		return err
   257  	}
   258  
   259  	// Reset our progress
   260  	d.current = 0
   261  
   262  	// Make the request. We first make a HEAD request so we can check
   263  	// if the server supports range queries. If the server/URL doesn't
   264  	// support HEAD requests, we just fall back to GET.
   265  	req, err := http.NewRequest("HEAD", src.String(), nil)
   266  	if err != nil {
   267  		return err
   268  	}
   269  
   270  	if d.userAgent != "" {
   271  		req.Header.Set("User-Agent", d.userAgent)
   272  	}
   273  
   274  	httpClient := &http.Client{
   275  		Transport: &http.Transport{
   276  			Proxy: http.ProxyFromEnvironment,
   277  		},
   278  	}
   279  
   280  	resp, err := httpClient.Do(req)
   281  	if err != nil {
   282  		log.Printf("[DEBUG] (download) Error making HTTP HEAD request: %s", err.Error())
   283  	} else {
   284  		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
   285  			// If the HEAD request succeeded, then attempt to set the range
   286  			// query if we can.
   287  			if resp.Header.Get("Accept-Ranges") == "bytes" {
   288  				if fi, err := dst.Stat(); err == nil {
   289  					if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   290  						req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   291  
   292  						d.current = uint64(fi.Size())
   293  					}
   294  				}
   295  			}
   296  		} else {
   297  			log.Printf("[DEBUG] (download) Unexpected HTTP response during HEAD request: %s", resp.Status)
   298  		}
   299  	}
   300  
   301  	// Set the request to GET now, and redo the query to download
   302  	req.Method = "GET"
   303  
   304  	resp, err = httpClient.Do(req)
   305  	if err != nil {
   306  		return err
   307  	} else {
   308  		if resp.StatusCode >= 400 && resp.StatusCode < 600 {
   309  			return fmt.Errorf("Error making HTTP GET request: %s", resp.Status)
   310  		}
   311  	}
   312  
   313  	d.total = d.current + uint64(resp.ContentLength)
   314  
   315  	var buffer [4096]byte
   316  	for {
   317  		n, err := resp.Body.Read(buffer[:])
   318  		if err != nil && err != io.EOF {
   319  			return err
   320  		}
   321  
   322  		d.current += uint64(n)
   323  
   324  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   325  			return werr
   326  		}
   327  
   328  		if err == io.EOF {
   329  			break
   330  		}
   331  	}
   332  	return nil
   333  }
   334  
   335  func (d *HTTPDownloader) Progress() uint64 {
   336  	return d.current
   337  }
   338  
   339  func (d *HTTPDownloader) Total() uint64 {
   340  	return d.total
   341  }
   342  
   343  // FileDownloader is an implementation of Downloader that downloads
   344  // files using the regular filesystem.
   345  type FileDownloader struct {
   346  	bufferSize *uint
   347  
   348  	active  bool
   349  	current uint64
   350  	total   uint64
   351  }
   352  
   353  func (d *FileDownloader) Progress() uint64 {
   354  	return d.current
   355  }
   356  
   357  func (d *FileDownloader) Total() uint64 {
   358  	return d.total
   359  }
   360  
   361  func (d *FileDownloader) Cancel() {
   362  	d.active = false
   363  }
   364  
   365  func (d *FileDownloader) Resume() {
   366  	// TODO: Implement
   367  }
   368  
   369  func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) {
   370  	var result string
   371  
   372  	// absolute path -- file://c:/absolute/path -> c:/absolute/path
   373  	if strings.HasSuffix(uri.Host, ":") {
   374  		result = path.Join(uri.Host, uri.Path)
   375  
   376  		// semi-absolute path (current drive letter)
   377  		//	-- file:///absolute/path -> drive:/absolute/path
   378  	} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
   379  		apath := uri.Path
   380  		components := strings.Split(apath, "/")
   381  		volume := filepath.VolumeName(base)
   382  
   383  		// semi-absolute absolute path (includes volume letter)
   384  		// -- file://drive:/path -> drive:/absolute/path
   385  		if len(components) > 1 && strings.HasSuffix(components[1], ":") {
   386  			volume = components[1]
   387  			apath = path.Join(components[2:]...)
   388  		}
   389  
   390  		result = path.Join(volume, apath)
   391  
   392  		// relative path -- file://./relative/path -> ./relative/path
   393  	} else if uri.Host == "." {
   394  		result = path.Join(base, uri.Path)
   395  
   396  		// relative path -- file://relative/path -> ./relative/path
   397  	} else {
   398  		result = path.Join(base, uri.Host, uri.Path)
   399  	}
   400  	return filepath.ToSlash(result), nil
   401  }
   402  
   403  func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
   404  	d.active = false
   405  
   406  	/* check the uri's scheme to make sure it matches */
   407  	if src == nil || src.Scheme != "file" {
   408  		return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
   409  	}
   410  	uri := src
   411  
   412  	/* use the current working directory as the base for relative uri's */
   413  	cwd, err := os.Getwd()
   414  	if err != nil {
   415  		return err
   416  	}
   417  
   418  	/* determine which uri format is being used and convert to a real path */
   419  	realpath, err := d.toPath(cwd, *uri)
   420  	if err != nil {
   421  		return err
   422  	}
   423  
   424  	/* download the file using the operating system's facilities */
   425  	d.current = 0
   426  	d.active = true
   427  
   428  	f, err := os.Open(realpath)
   429  	if err != nil {
   430  		return err
   431  	}
   432  	defer f.Close()
   433  
   434  	// get the file size
   435  	fi, err := f.Stat()
   436  	if err != nil {
   437  		return err
   438  	}
   439  	d.total = uint64(fi.Size())
   440  
   441  	// no bufferSize specified, so copy synchronously.
   442  	if d.bufferSize == nil {
   443  		var n int64
   444  		n, err = io.Copy(dst, f)
   445  		d.active = false
   446  
   447  		d.current += uint64(n)
   448  
   449  		// use a goro in case someone else wants to enable cancel/resume
   450  	} else {
   451  		errch := make(chan error)
   452  		go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
   453  			for d.active {
   454  				n, err := io.CopyN(w, r, int64(*d.bufferSize))
   455  				if err != nil {
   456  					break
   457  				}
   458  
   459  				d.current += uint64(n)
   460  			}
   461  			d.active = false
   462  			e <- err
   463  		}(d, f, dst, errch)
   464  
   465  		// ...and we spin until it's done
   466  		err = <-errch
   467  	}
   468  	f.Close()
   469  	return err
   470  }
   471  
   472  // SMBDownloader is an implementation of Downloader that downloads
   473  // files using the "\\" path format on Windows
   474  type SMBDownloader struct {
   475  	bufferSize *uint
   476  
   477  	active  bool
   478  	current uint64
   479  	total   uint64
   480  }
   481  
   482  func (d *SMBDownloader) Progress() uint64 {
   483  	return d.current
   484  }
   485  
   486  func (d *SMBDownloader) Total() uint64 {
   487  	return d.total
   488  }
   489  
   490  func (d *SMBDownloader) Cancel() {
   491  	d.active = false
   492  }
   493  
   494  func (d *SMBDownloader) Resume() {
   495  	// TODO: Implement
   496  }
   497  
   498  func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) {
   499  	const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator)
   500  
   501  	if runtime.GOOS != "windows" {
   502  		return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
   503  	}
   504  
   505  	return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
   506  }
   507  
   508  func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
   509  
   510  	/* first we warn the world if we're not running windows */
   511  	if runtime.GOOS != "windows" {
   512  		return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
   513  	}
   514  
   515  	d.active = false
   516  
   517  	/* convert the uri using the net/url module to a UNC path */
   518  	if src == nil || src.Scheme != "smb" {
   519  		return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
   520  	}
   521  	uri := src
   522  
   523  	/* use the current working directory as the base for relative uri's */
   524  	cwd, err := os.Getwd()
   525  	if err != nil {
   526  		return err
   527  	}
   528  
   529  	/* convert uri to an smb-path */
   530  	realpath, err := d.toPath(cwd, *uri)
   531  	if err != nil {
   532  		return err
   533  	}
   534  
   535  	/* Open up the "\\"-prefixed path using the Windows filesystem */
   536  	d.current = 0
   537  	d.active = true
   538  
   539  	f, err := os.Open(realpath)
   540  	if err != nil {
   541  		return err
   542  	}
   543  	defer f.Close()
   544  
   545  	// get the file size (at the risk of performance)
   546  	fi, err := f.Stat()
   547  	if err != nil {
   548  		return err
   549  	}
   550  	d.total = uint64(fi.Size())
   551  
   552  	// no bufferSize specified, so copy synchronously.
   553  	if d.bufferSize == nil {
   554  		var n int64
   555  		n, err = io.Copy(dst, f)
   556  		d.active = false
   557  
   558  		d.current += uint64(n)
   559  
   560  		// use a goro in case someone else wants to enable cancel/resume
   561  	} else {
   562  		errch := make(chan error)
   563  		go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
   564  			for d.active {
   565  				n, err := io.CopyN(w, r, int64(*d.bufferSize))
   566  				if err != nil {
   567  					break
   568  				}
   569  
   570  				d.current += uint64(n)
   571  			}
   572  			d.active = false
   573  			e <- err
   574  		}(d, f, dst, errch)
   575  
   576  		// ...and as usual we spin until it's done
   577  		err = <-errch
   578  	}
   579  	f.Close()
   580  	return err
   581  }