github.com/phobos182/packer@v0.2.3-0.20130819023704-c84d2aeffc68/common/download.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "crypto/md5" 6 "crypto/sha1" 7 "crypto/sha256" 8 "encoding/hex" 9 "errors" 10 "fmt" 11 "hash" 12 "io" 13 "log" 14 "net/http" 15 "net/url" 16 "os" 17 "runtime" 18 ) 19 20 // DownloadConfig is the configuration given to instantiate a new 21 // download instance. Once a configuration is used to instantiate 22 // a download client, it must not be modified. 23 type DownloadConfig struct { 24 // The source URL in the form of a string. 25 Url string 26 27 // This is the path to download the file to. 28 TargetPath string 29 30 // DownloaderMap maps a schema to a Download. 31 DownloaderMap map[string]Downloader 32 33 // If true, this will copy even a local file to the target 34 // location. If false, then it will "download" the file by just 35 // returning the local path to the file. 36 CopyFile bool 37 38 // The hashing implementation to use to checksum the downloaded file. 39 Hash hash.Hash 40 41 // The checksum for the downloaded file. The hash implementation configuration 42 // for the downloader will be used to verify with this checksum after 43 // it is downloaded. 44 Checksum []byte 45 } 46 47 // A DownloadClient helps download, verify checksums, etc. 48 type DownloadClient struct { 49 config *DownloadConfig 50 downloader Downloader 51 } 52 53 // HashForType returns the Hash implementation for the given string 54 // type, or nil if the type is not supported. 55 func HashForType(t string) hash.Hash { 56 switch t { 57 case "md5": 58 return md5.New() 59 case "sha1": 60 return sha1.New() 61 case "sha256": 62 return sha256.New() 63 default: 64 return nil 65 } 66 } 67 68 // NewDownloadClient returns a new DownloadClient for the given 69 // configuration. 
70 func NewDownloadClient(c *DownloadConfig) *DownloadClient { 71 if c.DownloaderMap == nil { 72 c.DownloaderMap = map[string]Downloader{ 73 "http": new(HTTPDownloader), 74 } 75 } 76 77 return &DownloadClient{config: c} 78 } 79 80 // A downloader is responsible for actually taking a remote URL and 81 // downloading it. 82 type Downloader interface { 83 Cancel() 84 Download(io.Writer, *url.URL) error 85 Progress() uint 86 Total() uint 87 } 88 89 func (d *DownloadClient) Cancel() { 90 // TODO(mitchellh): Implement 91 } 92 93 func (d *DownloadClient) Get() (string, error) { 94 // If we already have the file and it matches, then just return the target path. 95 if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { 96 log.Println("Initial checksum matched, no download needed.") 97 return d.config.TargetPath, nil 98 } 99 100 url, err := url.Parse(d.config.Url) 101 if err != nil { 102 return "", err 103 } 104 105 log.Printf("Parsed URL: %#v", url) 106 107 // Files when we don't copy the file are special cased. 108 var finalPath string 109 if url.Scheme == "file" && !d.config.CopyFile { 110 finalPath = url.Path 111 112 // Remove forward slash on absolute Windows file URLs before processing 113 if runtime.GOOS == "windows" && finalPath[0] == '/' { 114 finalPath = finalPath[1:len(finalPath)] 115 } 116 } else { 117 finalPath = d.config.TargetPath 118 119 var ok bool 120 d.downloader, ok = d.config.DownloaderMap[url.Scheme] 121 if !ok { 122 return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme) 123 } 124 125 // Otherwise, download using the downloader. 
126 f, err := os.Create(finalPath) 127 if err != nil { 128 return "", err 129 } 130 defer f.Close() 131 132 log.Printf("Downloading: %s", url.String()) 133 err = d.downloader.Download(f, url) 134 if err != nil { 135 return "", err 136 } 137 } 138 139 if d.config.Hash != nil { 140 var verify bool 141 verify, err = d.VerifyChecksum(finalPath) 142 if err == nil && !verify { 143 err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum)) 144 } 145 } 146 147 return finalPath, err 148 } 149 150 // PercentProgress returns the download progress as a percentage. 151 func (d *DownloadClient) PercentProgress() int { 152 if d.downloader == nil { 153 return -1 154 } 155 156 return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) 157 } 158 159 // VerifyChecksum tests that the path matches the checksum for the 160 // download. 161 func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { 162 if d.config.Checksum == nil || d.config.Hash == nil { 163 return false, errors.New("Checksum or Hash isn't set on download.") 164 } 165 166 f, err := os.Open(path) 167 if err != nil { 168 return false, err 169 } 170 defer f.Close() 171 172 log.Printf("Verifying checksum of %s", path) 173 d.config.Hash.Reset() 174 io.Copy(d.config.Hash, f) 175 return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil 176 } 177 178 // HTTPDownloader is an implementation of Downloader that downloads 179 // files over HTTP. 
type HTTPDownloader struct {
	progress uint // bytes received so far in the current download
	total    uint // response Content-Length converted to uint
}

// Cancel is intended to abort an in-progress download. It is not
// implemented yet and currently does nothing.
func (*HTTPDownloader) Cancel() {
	// TODO(mitchellh): Implement
}

// Download performs an HTTP GET of src and streams the response body
// into dst, updating Progress and Total as it goes. A non-200 response
// is returned as an error that includes the response body text.
func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
	log.Printf("Starting download: %s", src.String())
	req, err := http.NewRequest("GET", src.String(), nil)
	if err != nil {
		return err
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
		},
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return err
	}
	// The body must be closed on every path, or the connection leaks.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf(
			"Non-200 status code: %d. Getting error body.", resp.StatusCode)

		// Best-effort: whatever body we manage to read goes into the error.
		errorBody := new(bytes.Buffer)
		io.Copy(errorBody, resp.Body)
		return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
			resp.StatusCode, errorBody.String())
	}

	d.progress = 0
	// NOTE: ContentLength is -1 when the server reports no size; the uint
	// conversion then yields a very large total, which keeps percentage
	// progress near zero instead of producing a meaningful value.
	d.total = uint(resp.ContentLength)

	var buffer [4096]byte
	for {
		n, err := resp.Body.Read(buffer[:])

		// Per the io.Reader contract, bytes returned alongside an error
		// must be consumed before the error is handled.
		if n > 0 {
			d.progress += uint(n)
			if _, werr := dst.Write(buffer[:n]); werr != nil {
				return werr
			}
		}

		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

// Progress returns the number of bytes received so far.
func (d *HTTPDownloader) Progress() uint {
	return d.progress
}

// Total returns the size reported by the server for the current download.
func (d *HTTPDownloader) Total() uint {
	return d.total
}