github.com/aspring/packer@v0.8.1-0.20150629211158-9db281ac0f89/common/download.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "crypto/md5" 6 "crypto/sha1" 7 "crypto/sha256" 8 "crypto/sha512" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "hash" 13 "io" 14 "log" 15 "net/http" 16 "net/url" 17 "os" 18 "runtime" 19 ) 20 21 // DownloadConfig is the configuration given to instantiate a new 22 // download instance. Once a configuration is used to instantiate 23 // a download client, it must not be modified. 24 type DownloadConfig struct { 25 // The source URL in the form of a string. 26 Url string 27 28 // This is the path to download the file to. 29 TargetPath string 30 31 // DownloaderMap maps a schema to a Download. 32 DownloaderMap map[string]Downloader 33 34 // If true, this will copy even a local file to the target 35 // location. If false, then it will "download" the file by just 36 // returning the local path to the file. 37 CopyFile bool 38 39 // The hashing implementation to use to checksum the downloaded file. 40 Hash hash.Hash 41 42 // The checksum for the downloaded file. The hash implementation configuration 43 // for the downloader will be used to verify with this checksum after 44 // it is downloaded. 45 Checksum []byte 46 47 // What to use for the user agent for HTTP requests. If set to "", use the 48 // default user agent provided by Go. 49 UserAgent string 50 } 51 52 // A DownloadClient helps download, verify checksums, etc. 53 type DownloadClient struct { 54 config *DownloadConfig 55 downloader Downloader 56 } 57 58 // HashForType returns the Hash implementation for the given string 59 // type, or nil if the type is not supported. 60 func HashForType(t string) hash.Hash { 61 switch t { 62 case "md5": 63 return md5.New() 64 case "sha1": 65 return sha1.New() 66 case "sha256": 67 return sha256.New() 68 case "sha512": 69 return sha512.New() 70 default: 71 return nil 72 } 73 } 74 75 // NewDownloadClient returns a new DownloadClient for the given 76 // configuration. 77 func NewDownloadClient(c *DownloadConfig) *DownloadClient { 78 if c.DownloaderMap == nil { 79 c.DownloaderMap = map[string]Downloader{ 80 "http": &HTTPDownloader{userAgent: c.UserAgent}, 81 "https": &HTTPDownloader{userAgent: c.UserAgent}, 82 } 83 } 84 85 return &DownloadClient{config: c} 86 } 87 88 // A downloader is responsible for actually taking a remote URL and 89 // downloading it. 90 type Downloader interface { 91 Cancel() 92 Download(*os.File, *url.URL) error 93 Progress() uint 94 Total() uint 95 } 96 97 func (d *DownloadClient) Cancel() { 98 // TODO(mitchellh): Implement 99 } 100 101 func (d *DownloadClient) Get() (string, error) { 102 // If we already have the file and it matches, then just return the target path. 103 if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { 104 log.Println("Initial checksum matched, no download needed.") 105 return d.config.TargetPath, nil 106 } 107 108 url, err := url.Parse(d.config.Url) 109 if err != nil { 110 return "", err 111 } 112 113 log.Printf("Parsed URL: %#v", url) 114 115 // Files when we don't copy the file are special cased. 116 var f *os.File 117 var finalPath string 118 if url.Scheme == "file" && !d.config.CopyFile { 119 finalPath = url.Path 120 121 // Remove forward slash on absolute Windows file URLs before processing 122 if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' { 123 finalPath = finalPath[1:len(finalPath)] 124 } 125 } else { 126 finalPath = d.config.TargetPath 127 128 var ok bool 129 d.downloader, ok = d.config.DownloaderMap[url.Scheme] 130 if !ok { 131 return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme) 132 } 133 134 // Otherwise, download using the downloader. 135 f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) 136 if err != nil { 137 return "", err 138 } 139 140 log.Printf("Downloading: %s", url.String()) 141 err = d.downloader.Download(f, url) 142 f.Close() 143 if err != nil { 144 return "", err 145 } 146 } 147 148 if d.config.Hash != nil { 149 var verify bool 150 verify, err = d.VerifyChecksum(finalPath) 151 if err == nil && !verify { 152 // Delete the file 153 os.Remove(finalPath) 154 155 err = fmt.Errorf( 156 "checksums didn't match expected: %s", 157 hex.EncodeToString(d.config.Checksum)) 158 } 159 } 160 161 return finalPath, err 162 } 163 164 // PercentProgress returns the download progress as a percentage. 165 func (d *DownloadClient) PercentProgress() int { 166 if d.downloader == nil { 167 return -1 168 } 169 170 return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) 171 } 172 173 // VerifyChecksum tests that the path matches the checksum for the 174 // download. 175 func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { 176 if d.config.Checksum == nil || d.config.Hash == nil { 177 return false, errors.New("Checksum or Hash isn't set on download.") 178 } 179 180 f, err := os.Open(path) 181 if err != nil { 182 return false, err 183 } 184 defer f.Close() 185 186 log.Printf("Verifying checksum of %s", path) 187 d.config.Hash.Reset() 188 io.Copy(d.config.Hash, f) 189 return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil 190 } 191 192 // HTTPDownloader is an implementation of Downloader that downloads 193 // files over HTTP. 194 type HTTPDownloader struct { 195 progress uint 196 total uint 197 userAgent string 198 } 199 200 func (*HTTPDownloader) Cancel() { 201 // TODO(mitchellh): Implement 202 } 203 204 func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { 205 log.Printf("Starting download: %s", src.String()) 206 207 // Seek to the beginning by default 208 if _, err := dst.Seek(0, 0); err != nil { 209 return err 210 } 211 212 // Reset our progress 213 d.progress = 0 214 215 // Make the request. We first make a HEAD request so we can check 216 // if the server supports range queries. If the server/URL doesn't 217 // support HEAD requests, we just fall back to GET. 218 req, err := http.NewRequest("HEAD", src.String(), nil) 219 if err != nil { 220 return err 221 } 222 223 if d.userAgent != "" { 224 req.Header.Set("User-Agent", d.userAgent) 225 } 226 227 httpClient := &http.Client{ 228 Transport: &http.Transport{ 229 Proxy: http.ProxyFromEnvironment, 230 }, 231 } 232 233 resp, err := httpClient.Do(req) 234 if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) { 235 // If the HEAD request succeeded, then attempt to set the range 236 // query if we can. 237 if resp.Header.Get("Accept-Ranges") == "bytes" { 238 if fi, err := dst.Stat(); err == nil { 239 if _, err = dst.Seek(0, os.SEEK_END); err == nil { 240 req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) 241 d.progress = uint(fi.Size()) 242 } 243 } 244 } 245 } 246 247 // Set the request to GET now, and redo the query to download 248 req.Method = "GET" 249 250 resp, err = httpClient.Do(req) 251 if err != nil { 252 return err 253 } 254 255 d.total = d.progress + uint(resp.ContentLength) 256 var buffer [4096]byte 257 for { 258 n, err := resp.Body.Read(buffer[:]) 259 if err != nil && err != io.EOF { 260 return err 261 } 262 263 d.progress += uint(n) 264 265 if _, werr := dst.Write(buffer[:n]); werr != nil { 266 return werr 267 } 268 269 if err == io.EOF { 270 break 271 } 272 } 273 274 return nil 275 } 276 277 func (d *HTTPDownloader) Progress() uint { 278 return d.progress 279 } 280 281 func (d *HTTPDownloader) Total() uint { 282 return d.total 283 }