github.com/aclaygray/packer@v1.3.2/common/download.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha1"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"log"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"path"
    19  	"path/filepath"
    20  	"runtime"
    21  	"strings"
    22  
    23  	"github.com/hashicorp/packer/packer"
    24  )
    25  
    26  // DownloadConfig is the configuration given to instantiate a new
    27  // download instance. Once a configuration is used to instantiate
    28  // a download client, it must not be modified.
    29  type DownloadConfig struct {
    30  	// The source URL in the form of a string.
    31  	Url string
    32  
    33  	// This is the path to download the file to.
    34  	TargetPath string
    35  
    36  	// DownloaderMap maps a schema to a Download.
    37  	DownloaderMap map[string]Downloader
    38  
    39  	// If true, this will copy even a local file to the target
    40  	// location. If false, then it will "download" the file by just
    41  	// returning the local path to the file.
    42  	CopyFile bool
    43  
    44  	// The hashing implementation to use to checksum the downloaded file.
    45  	Hash hash.Hash
    46  
    47  	// The checksum for the downloaded file. The hash implementation configuration
    48  	// for the downloader will be used to verify with this checksum after
    49  	// it is downloaded.
    50  	Checksum []byte
    51  
    52  	// What to use for the user agent for HTTP requests. If set to "", use the
    53  	// default user agent provided by Go.
    54  	UserAgent string
    55  }
    56  
    57  // A DownloadClient helps download, verify checksums, etc.
    58  type DownloadClient struct {
    59  	config *DownloadConfig
    60  }
    61  
    62  // HashForType returns the Hash implementation for the given string
    63  // type, or nil if the type is not supported.
    64  func HashForType(t string) hash.Hash {
    65  	switch t {
    66  	case "md5":
    67  		return md5.New()
    68  	case "sha1":
    69  		return sha1.New()
    70  	case "sha256":
    71  		return sha256.New()
    72  	case "sha512":
    73  		return sha512.New()
    74  	default:
    75  		return nil
    76  	}
    77  }
    78  
    79  // NewDownloadClient returns a new DownloadClient for the given
    80  // configuration.
    81  func NewDownloadClient(c *DownloadConfig, ui packer.Ui) *DownloadClient {
    82  	// Create downloader map if it hasn't been specified already.
    83  	if c.DownloaderMap == nil {
    84  		c.DownloaderMap = map[string]Downloader{
    85  			"file":  &FileDownloader{Ui: ui, bufferSize: nil},
    86  			"http":  &HTTPDownloader{Ui: ui, userAgent: c.UserAgent},
    87  			"https": &HTTPDownloader{Ui: ui, userAgent: c.UserAgent},
    88  			"smb":   &SMBDownloader{Ui: ui, bufferSize: nil},
    89  		}
    90  	}
    91  	return &DownloadClient{config: c}
    92  }
    93  
    94  // Downloader defines what capabilities a downloader should have.
    95  type Downloader interface {
    96  	Resume()
    97  	Cancel()
    98  	ProgressBar() packer.ProgressBar
    99  }
   100  
   101  // A LocalDownloader is responsible for converting a uri to a local path
   102  //	that the platform can open directly.
   103  type LocalDownloader interface {
   104  	toPath(string, url.URL) (string, error)
   105  }
   106  
   107  // A RemoteDownloader is responsible for actually taking a remote URL and
   108  //	downloading it.
   109  type RemoteDownloader interface {
   110  	Download(*os.File, *url.URL) error
   111  }
   112  
   113  func (d *DownloadClient) Cancel() {
   114  	// TODO(mitchellh): Implement
   115  }
   116  
   117  func (d *DownloadClient) Get() (string, error) {
   118  	// If we already have the file and it matches, then just return the target path.
   119  	if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
   120  		log.Println("[DEBUG] Initial checksum matched, no download needed.")
   121  		return d.config.TargetPath, nil
   122  	}
   123  
   124  	/* parse the configuration url into a net/url object */
   125  	u, err := url.Parse(d.config.Url)
   126  	if err != nil {
   127  		return "", err
   128  	}
   129  	log.Printf("Parsed URL: %#v", u)
   130  
   131  	/* use the current working directory as the base for relative uri's */
   132  	cwd, err := os.Getwd()
   133  	if err != nil {
   134  		return "", err
   135  	}
   136  
   137  	// Determine which is the correct downloader to use
   138  	var finalPath string
   139  
   140  	var ok bool
   141  	downloader, ok := d.config.DownloaderMap[u.Scheme]
   142  	if !ok {
   143  		return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
   144  	}
   145  
   146  	remote, ok := downloader.(RemoteDownloader)
   147  	if !ok {
   148  		return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, downloader)
   149  	}
   150  
   151  	local, ok := downloader.(LocalDownloader)
   152  	if !ok && !d.config.CopyFile {
   153  		d.config.CopyFile = true
   154  	}
   155  
   156  	// If we're copying the file, then just use the actual downloader
   157  	if d.config.CopyFile {
   158  		var f *os.File
   159  		finalPath = d.config.TargetPath
   160  
   161  		f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
   162  		if err != nil {
   163  			return "", err
   164  		}
   165  
   166  		log.Printf("[DEBUG] Downloading: %s", u.String())
   167  		err = remote.Download(f, u)
   168  		f.Close()
   169  		if err != nil {
   170  			return "", err
   171  		}
   172  
   173  		// Otherwise if our Downloader is a LocalDownloader we can just use the
   174  		//	path after transforming it.
   175  	} else {
   176  		finalPath, err = local.toPath(cwd, *u)
   177  		if err != nil {
   178  			return "", err
   179  		}
   180  
   181  		log.Printf("[DEBUG] Using local file: %s", finalPath)
   182  	}
   183  
   184  	if d.config.Hash != nil {
   185  		var verify bool
   186  		verify, err = d.VerifyChecksum(finalPath)
   187  		if err == nil && !verify {
   188  			// Only delete the file if we made a copy or downloaded it
   189  			if d.config.CopyFile {
   190  				os.Remove(finalPath)
   191  			}
   192  
   193  			err = fmt.Errorf(
   194  				"checksums didn't match expected: %s",
   195  				hex.EncodeToString(d.config.Checksum))
   196  		}
   197  	}
   198  
   199  	return finalPath, err
   200  }
   201  
   202  // VerifyChecksum tests that the path matches the checksum for the
   203  // download.
   204  func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
   205  	if d.config.Checksum == nil || d.config.Hash == nil {
   206  		return false, errors.New("Checksum or Hash isn't set on download.")
   207  	}
   208  
   209  	f, err := os.Open(path)
   210  	if err != nil {
   211  		return false, err
   212  	}
   213  	defer f.Close()
   214  
   215  	log.Printf("Verifying checksum of %s", path)
   216  	d.config.Hash.Reset()
   217  	io.Copy(d.config.Hash, f)
   218  	return bytes.Equal(d.config.Hash.Sum(nil), d.config.Checksum), nil
   219  }
   220  
   221  // HTTPDownloader is an implementation of Downloader that downloads
   222  // files over HTTP.
   223  type HTTPDownloader struct {
   224  	userAgent string
   225  
   226  	Ui packer.Ui
   227  }
   228  
   229  func (d *HTTPDownloader) Cancel() {
   230  	// TODO(mitchellh): Implement
   231  }
   232  
   233  func (d *HTTPDownloader) Resume() {
   234  	// TODO(mitchellh): Implement
   235  }
   236  
   237  func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
   238  	log.Printf("Starting download over HTTP: %s", src.String())
   239  
   240  	// Seek to the beginning by default
   241  	if _, err := dst.Seek(0, 0); err != nil {
   242  		return err
   243  	}
   244  
   245  	var current int64
   246  
   247  	// Make the request. We first make a HEAD request so we can check
   248  	// if the server supports range queries. If the server/URL doesn't
   249  	// support HEAD requests, we just fall back to GET.
   250  	req, err := http.NewRequest("HEAD", src.String(), nil)
   251  	if err != nil {
   252  		return err
   253  	}
   254  
   255  	if d.userAgent != "" {
   256  		req.Header.Set("User-Agent", d.userAgent)
   257  	}
   258  
   259  	httpClient := &http.Client{
   260  		Transport: &http.Transport{
   261  			Proxy: http.ProxyFromEnvironment,
   262  		},
   263  	}
   264  
   265  	resp, err := httpClient.Do(req)
   266  	if err != nil || resp == nil {
   267  
   268  		if resp == nil {
   269  			log.Printf("[DEBUG] (download) HTTP connection error: %s", err.Error())
   270  
   271  		} else if resp.StatusCode >= 400 && resp.StatusCode < 600 {
   272  			log.Printf("[DEBUG] (download) Non-successful HTTP status code (%s) while making HEAD request: %s", resp.Status, err.Error())
   273  
   274  		} else {
   275  			log.Printf("[DEBUG] (download) Error making HTTP HEAD request (%s): %s", resp.Status, err.Error())
   276  		}
   277  
   278  	} else {
   279  
   280  		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
   281  			// If the HEAD request succeeded, then attempt to set the range
   282  			// query if we can.
   283  
   284  			if resp.Header.Get("Accept-Ranges") == "bytes" {
   285  				if fi, err := dst.Stat(); err == nil {
   286  					if _, err = dst.Seek(0, os.SEEK_END); err == nil {
   287  						req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
   288  
   289  						current = fi.Size()
   290  					}
   291  				}
   292  			}
   293  
   294  		} else {
   295  			log.Printf("[DEBUG] (download) Unexpected HTTP response during HEAD request: %s", resp.Status)
   296  		}
   297  	}
   298  
   299  	// Set the request to GET now, and redo the query to download
   300  	req.Method = "GET"
   301  
   302  	resp, err = httpClient.Do(req)
   303  	if err == nil && (resp.StatusCode >= 400 && resp.StatusCode < 600) {
   304  		return fmt.Errorf("Error making HTTP GET request: %s", resp.Status)
   305  
   306  	} else if err != nil {
   307  		if resp == nil {
   308  			return fmt.Errorf("HTTP connection error: %s", err.Error())
   309  
   310  		} else if resp.StatusCode >= 400 && resp.StatusCode < 600 {
   311  			return fmt.Errorf("HTTP %s error: %s", resp.Status, err.Error())
   312  		}
   313  		return fmt.Errorf("HTTP error: %s", err.Error())
   314  	}
   315  
   316  	total := current + resp.ContentLength
   317  
   318  	bar := d.ProgressBar()
   319  	bar.Start(total)
   320  	defer bar.Finish()
   321  	bar.Add(current)
   322  
   323  	body := bar.NewProxyReader(resp.Body)
   324  
   325  	var buffer [4096]byte
   326  	for {
   327  		n, err := body.Read(buffer[:])
   328  		if err != nil && err != io.EOF {
   329  			return err
   330  		}
   331  
   332  		if _, werr := dst.Write(buffer[:n]); werr != nil {
   333  			return werr
   334  		}
   335  
   336  		if err == io.EOF {
   337  			break
   338  		}
   339  	}
   340  	return nil
   341  }
   342  
   343  // FileDownloader is an implementation of Downloader that downloads
   344  // files using the regular filesystem.
   345  type FileDownloader struct {
   346  	bufferSize *uint
   347  
   348  	active bool
   349  	Ui     packer.Ui
   350  }
   351  
   352  func (d *FileDownloader) Cancel() {
   353  	d.active = false
   354  }
   355  
   356  func (d *FileDownloader) Resume() {
   357  	// TODO: Implement
   358  }
   359  
   360  func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) {
   361  	var result string
   362  
   363  	// absolute path -- file://c:/absolute/path -> c:/absolute/path
   364  	if strings.HasSuffix(uri.Host, ":") {
   365  		result = path.Join(uri.Host, uri.Path)
   366  
   367  		// semi-absolute path (current drive letter)
   368  		//	-- file:///absolute/path -> drive:/absolute/path
   369  	} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
   370  		apath := uri.Path
   371  		components := strings.Split(apath, "/")
   372  		volume := filepath.VolumeName(base)
   373  
   374  		// semi-absolute absolute path (includes volume letter)
   375  		// -- file://drive:/path -> drive:/absolute/path
   376  		if len(components) > 1 && strings.HasSuffix(components[1], ":") {
   377  			volume = components[1]
   378  			apath = path.Join(components[2:]...)
   379  		}
   380  
   381  		result = path.Join(volume, apath)
   382  
   383  		// relative path -- file://./relative/path -> ./relative/path
   384  	} else if uri.Host == "." {
   385  		result = path.Join(base, uri.Path)
   386  
   387  		// relative path -- file://relative/path -> ./relative/path
   388  	} else {
   389  		result = path.Join(base, uri.Host, uri.Path)
   390  	}
   391  	return filepath.ToSlash(result), nil
   392  }
   393  
   394  func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
   395  	d.active = false
   396  
   397  	/* check the uri's scheme to make sure it matches */
   398  	if src == nil || src.Scheme != "file" {
   399  		return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
   400  	}
   401  	uri := src
   402  
   403  	/* use the current working directory as the base for relative uri's */
   404  	cwd, err := os.Getwd()
   405  	if err != nil {
   406  		return err
   407  	}
   408  
   409  	/* determine which uri format is being used and convert to a real path */
   410  	realpath, err := d.toPath(cwd, *uri)
   411  	if err != nil {
   412  		return err
   413  	}
   414  
   415  	/* download the file using the operating system's facilities */
   416  	d.active = true
   417  
   418  	f, err := os.Open(realpath)
   419  	if err != nil {
   420  		return err
   421  	}
   422  	defer f.Close()
   423  
   424  	// get the file size
   425  	fi, err := f.Stat()
   426  	if err != nil {
   427  		return err
   428  	}
   429  
   430  	bar := d.ProgressBar()
   431  
   432  	bar.Start(fi.Size())
   433  	defer bar.Finish()
   434  	fProxy := bar.NewProxyReader(f)
   435  
   436  	// no bufferSize specified, so copy synchronously.
   437  	if d.bufferSize == nil {
   438  		_, err = io.Copy(dst, fProxy)
   439  		d.active = false
   440  
   441  		// use a goro in case someone else wants to enable cancel/resume
   442  	} else {
   443  		errch := make(chan error)
   444  		go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
   445  			for d.active {
   446  				_, err := io.CopyN(w, r, int64(*d.bufferSize))
   447  				if err != nil {
   448  					break
   449  				}
   450  
   451  			}
   452  			d.active = false
   453  			e <- err
   454  		}(d, fProxy, dst, errch)
   455  
   456  		// ...and we spin until it's done
   457  		err = <-errch
   458  	}
   459  
   460  	return err
   461  }
   462  
   463  // SMBDownloader is an implementation of Downloader that downloads
   464  // files using the "\\" path format on Windows
   465  type SMBDownloader struct {
   466  	bufferSize *uint
   467  
   468  	active bool
   469  	Ui     packer.Ui
   470  }
   471  
   472  func (d *SMBDownloader) Cancel() {
   473  	d.active = false
   474  }
   475  
   476  func (d *SMBDownloader) Resume() {
   477  	// TODO: Implement
   478  }
   479  
   480  func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) {
   481  	const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator)
   482  
   483  	if runtime.GOOS != "windows" {
   484  		return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
   485  	}
   486  
   487  	return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
   488  }
   489  
   490  func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
   491  
   492  	/* first we warn the world if we're not running windows */
   493  	if runtime.GOOS != "windows" {
   494  		return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
   495  	}
   496  
   497  	d.active = false
   498  
   499  	/* convert the uri using the net/url module to a UNC path */
   500  	if src == nil || src.Scheme != "smb" {
   501  		return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
   502  	}
   503  	uri := src
   504  
   505  	/* use the current working directory as the base for relative uri's */
   506  	cwd, err := os.Getwd()
   507  	if err != nil {
   508  		return err
   509  	}
   510  
   511  	/* convert uri to an smb-path */
   512  	realpath, err := d.toPath(cwd, *uri)
   513  	if err != nil {
   514  		return err
   515  	}
   516  
   517  	/* Open up the "\\"-prefixed path using the Windows filesystem */
   518  	d.active = true
   519  
   520  	f, err := os.Open(realpath)
   521  	if err != nil {
   522  		return err
   523  	}
   524  	defer f.Close()
   525  
   526  	// get the file size (at the risk of performance)
   527  	fi, err := f.Stat()
   528  	if err != nil {
   529  		return err
   530  	}
   531  
   532  	bar := d.ProgressBar()
   533  
   534  	bar.Start(fi.Size())
   535  	defer bar.Finish()
   536  	fProxy := bar.NewProxyReader(f)
   537  
   538  	// no bufferSize specified, so copy synchronously.
   539  	if d.bufferSize == nil {
   540  		_, err = io.Copy(dst, fProxy)
   541  		d.active = false
   542  
   543  		// use a goro in case someone else wants to enable cancel/resume
   544  	} else {
   545  		errch := make(chan error)
   546  		go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
   547  			for d.active {
   548  				_, err := io.CopyN(w, r, int64(*d.bufferSize))
   549  				if err != nil {
   550  					break
   551  				}
   552  
   553  			}
   554  			d.active = false
   555  			e <- err
   556  		}(d, fProxy, dst, errch)
   557  
   558  		// ...and as usual we spin until it's done
   559  		err = <-errch
   560  	}
   561  	return err
   562  }
   563  
   564  func (d *HTTPDownloader) ProgressBar() packer.ProgressBar { return d.Ui.ProgressBar() }
   565  func (d *FileDownloader) ProgressBar() packer.ProgressBar { return d.Ui.ProgressBar() }
   566  func (d *SMBDownloader) ProgressBar() packer.ProgressBar  { return d.Ui.ProgressBar() }