github.com/sneal/packer@v0.5.2/common/download.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "crypto/md5" 6 "crypto/sha1" 7 "crypto/sha256" 8 "crypto/sha512" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "hash" 13 "io" 14 "log" 15 "net/http" 16 "net/url" 17 "os" 18 "runtime" 19 ) 20 21 // DownloadConfig is the configuration given to instantiate a new 22 // download instance. Once a configuration is used to instantiate 23 // a download client, it must not be modified. 24 type DownloadConfig struct { 25 // The source URL in the form of a string. 26 Url string 27 28 // This is the path to download the file to. 29 TargetPath string 30 31 // DownloaderMap maps a schema to a Download. 32 DownloaderMap map[string]Downloader 33 34 // If true, this will copy even a local file to the target 35 // location. If false, then it will "download" the file by just 36 // returning the local path to the file. 37 CopyFile bool 38 39 // The hashing implementation to use to checksum the downloaded file. 40 Hash hash.Hash 41 42 // The checksum for the downloaded file. The hash implementation configuration 43 // for the downloader will be used to verify with this checksum after 44 // it is downloaded. 45 Checksum []byte 46 47 // What to use for the user agent for HTTP requests. If set to "", use the 48 // default user agent provided by Go. 49 UserAgent string 50 } 51 52 // A DownloadClient helps download, verify checksums, etc. 53 type DownloadClient struct { 54 config *DownloadConfig 55 downloader Downloader 56 } 57 58 // HashForType returns the Hash implementation for the given string 59 // type, or nil if the type is not supported. 60 func HashForType(t string) hash.Hash { 61 switch t { 62 case "md5": 63 return md5.New() 64 case "sha1": 65 return sha1.New() 66 case "sha256": 67 return sha256.New() 68 case "sha512": 69 return sha512.New() 70 default: 71 return nil 72 } 73 } 74 75 // NewDownloadClient returns a new DownloadClient for the given 76 // configuration. 
77 func NewDownloadClient(c *DownloadConfig) *DownloadClient { 78 if c.DownloaderMap == nil { 79 c.DownloaderMap = map[string]Downloader{ 80 "http": &HTTPDownloader{userAgent: c.UserAgent}, 81 "https": &HTTPDownloader{userAgent: c.UserAgent}, 82 } 83 } 84 85 return &DownloadClient{config: c} 86 } 87 88 // A downloader is responsible for actually taking a remote URL and 89 // downloading it. 90 type Downloader interface { 91 Cancel() 92 Download(io.Writer, *url.URL) error 93 Progress() uint 94 Total() uint 95 } 96 97 func (d *DownloadClient) Cancel() { 98 // TODO(mitchellh): Implement 99 } 100 101 func (d *DownloadClient) Get() (string, error) { 102 // If we already have the file and it matches, then just return the target path. 103 if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { 104 log.Println("Initial checksum matched, no download needed.") 105 return d.config.TargetPath, nil 106 } 107 108 url, err := url.Parse(d.config.Url) 109 if err != nil { 110 return "", err 111 } 112 113 log.Printf("Parsed URL: %#v", url) 114 115 // Files when we don't copy the file are special cased. 116 var finalPath string 117 if url.Scheme == "file" && !d.config.CopyFile { 118 finalPath = url.Path 119 120 // Remove forward slash on absolute Windows file URLs before processing 121 if runtime.GOOS == "windows" && finalPath[0] == '/' { 122 finalPath = finalPath[1:len(finalPath)] 123 } 124 } else { 125 finalPath = d.config.TargetPath 126 127 var ok bool 128 d.downloader, ok = d.config.DownloaderMap[url.Scheme] 129 if !ok { 130 return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme) 131 } 132 133 // Otherwise, download using the downloader. 
134 f, err := os.Create(finalPath) 135 if err != nil { 136 return "", err 137 } 138 defer f.Close() 139 140 log.Printf("Downloading: %s", url.String()) 141 err = d.downloader.Download(f, url) 142 if err != nil { 143 return "", err 144 } 145 } 146 147 if d.config.Hash != nil { 148 var verify bool 149 verify, err = d.VerifyChecksum(finalPath) 150 if err == nil && !verify { 151 err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum)) 152 } 153 } 154 155 return finalPath, err 156 } 157 158 // PercentProgress returns the download progress as a percentage. 159 func (d *DownloadClient) PercentProgress() int { 160 if d.downloader == nil { 161 return -1 162 } 163 164 return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) 165 } 166 167 // VerifyChecksum tests that the path matches the checksum for the 168 // download. 169 func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { 170 if d.config.Checksum == nil || d.config.Hash == nil { 171 return false, errors.New("Checksum or Hash isn't set on download.") 172 } 173 174 f, err := os.Open(path) 175 if err != nil { 176 return false, err 177 } 178 defer f.Close() 179 180 log.Printf("Verifying checksum of %s", path) 181 d.config.Hash.Reset() 182 io.Copy(d.config.Hash, f) 183 return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil 184 } 185 186 // HTTPDownloader is an implementation of Downloader that downloads 187 // files over HTTP. 
type HTTPDownloader struct {
	// progress is the number of body bytes received by the current or most
	// recent Download call.
	progress uint

	// total is the size announced by the server's Content-Length, or 0 when
	// the server did not announce one.
	total uint

	// userAgent is sent as the User-Agent header when non-empty.
	userAgent string
}

// Cancel is intended to abort an in-progress download. It is currently a
// no-op.
func (*HTTPDownloader) Cancel() {
	// TODO(mitchellh): Implement
}

// Download issues a GET request for src and streams the response body into
// dst, updating the values reported by Progress and Total as it goes.
//
// Proxy settings are taken from the environment. Any response status other
// than 200 is returned as an error that includes the response body. Errors
// from building the request, performing it, reading the body, or writing to
// dst are returned as-is.
func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
	log.Printf("Starting download: %s", src.String())
	req, err := http.NewRequest("GET", src.String(), nil)
	if err != nil {
		return err
	}

	if d.userAgent != "" {
		req.Header.Set("User-Agent", d.userAgent)
	}

	transport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
	}
	// The transport is private to this call, so nothing can ever reuse its
	// pooled connections; release them instead of leaving them open.
	defer transport.CloseIdleConnections()

	httpClient := &http.Client{Transport: transport}

	resp, err := httpClient.Do(req)
	if err != nil {
		return err
	}
	// The body must always be closed or the underlying connection leaks.
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		log.Printf(
			"Non-200 status code: %d. Getting error body.", resp.StatusCode)

		// Best effort: whatever part of the body could be read is reported.
		errorBody := new(bytes.Buffer)
		io.Copy(errorBody, resp.Body)
		return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
			resp.StatusCode, errorBody.String())
	}

	d.progress = 0

	// ContentLength is -1 when the length is unknown (for example with
	// chunked encoding); converting that to uint would wrap around to a
	// huge bogus total, so report 0 instead.
	d.total = 0
	if resp.ContentLength > 0 {
		d.total = uint(resp.ContentLength)
	}

	var buffer [4096]byte
	for {
		n, err := resp.Body.Read(buffer[:])

		// A Read may return data together with an error; consume the data
		// before looking at the error.
		if n > 0 {
			d.progress += uint(n)

			if _, werr := dst.Write(buffer[:n]); werr != nil {
				return werr
			}
		}

		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
	}

	return nil
}

// Progress returns the number of bytes received so far by the current or
// most recent Download call.
func (d *HTTPDownloader) Progress() uint {
	return d.progress
}

// Total returns the size the server announced for the current or most
// recent download, or 0 when no size was announced.
func (d *HTTPDownloader) Total() uint {
	return d.total
}