github.com/koron/hk@v0.0.0-20150303213137-b8aeaa3ab34c/update.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "crypto/sha256" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "log" 13 "math/rand" 14 "net/http" 15 "os" 16 "os/exec" 17 "runtime" 18 "time" 19 20 "github.com/heroku/hk/Godeps/_workspace/src/bitbucket.org/kardianos/osext" 21 "github.com/heroku/hk/Godeps/_workspace/src/github.com/inconshreveable/go-update" 22 "github.com/heroku/hk/Godeps/_workspace/src/github.com/kr/binarydist" 23 ) 24 25 var cmdUpdate = &Command{ 26 Run: runUpdate, 27 Usage: "update", 28 Category: "hk", 29 Long: ` 30 Update downloads and installs the next version of hk. 31 32 This command is unlisted, since users never have to run it directly. 33 `, 34 } 35 36 func runUpdate(cmd *Command, args []string) { 37 if updater == nil { 38 printFatal("Dev builds don't support auto-updates") 39 } 40 if err := updater.update(); err != nil { 41 printFatal(err.Error()) 42 } 43 } 44 45 const ( 46 upcktimePath = "cktime" 47 plat = runtime.GOOS + "-" + runtime.GOARCH 48 ) 49 50 var ErrHashMismatch = errors.New("new file hash mismatch after patch") 51 52 // Update protocol. 53 // 54 // GET hk.heroku.com/hk/current/linux-amd64.json 55 // 56 // 200 ok 57 // { 58 // "Version": "2", 59 // "Sha256": "..." // base64 60 // } 61 // 62 // then 63 // 64 // GET hkpatch.s3.amazonaws.com/hk/1/2/linux-amd64 65 // 66 // 200 ok 67 // [bsdiff data] 68 // 69 // or 70 // 71 // GET hkdist.s3.amazonaws.com/hk/2/linux-amd64.gz 72 // 73 // 200 ok 74 // [gzipped executable data] 75 type Updater struct { 76 apiURL string 77 cmdName string 78 binURL string 79 diffURL string 80 dir string 81 info struct { 82 Version string 83 Sha256 []byte 84 } 85 } 86 87 func (u *Updater) backgroundRun() { 88 os.MkdirAll(u.dir, 0777) 89 if u.wantUpdate() { 90 if err := update.SanityCheck(); err != nil { 91 // fail 92 return 93 } 94 self, err := osext.Executable() 95 if err != nil { 96 // fail update, couldn't figure out path to self 97 return 98 } 99 // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. 100 l := exec.Command("logger", "-thk") 101 c := exec.Command(self, "update") 102 if w, err := l.StdinPipe(); err == nil && l.Start() == nil { 103 c.Stdout = w 104 c.Stderr = w 105 } 106 c.Start() 107 } 108 } 109 110 func (u *Updater) wantUpdate() bool { 111 path := u.dir + upcktimePath 112 if Version == "dev" || readTime(path).After(time.Now()) { 113 return false 114 } 115 wait := 12*time.Hour + randDuration(8*time.Hour) 116 return writeTime(path, time.Now().Add(wait)) 117 } 118 119 func (u *Updater) update() error { 120 path, err := osext.Executable() 121 if err != nil { 122 return err 123 } 124 old, err := os.Open(path) 125 if err != nil { 126 return err 127 } 128 defer old.Close() 129 130 err = u.fetchInfo() 131 if err != nil { 132 return err 133 } 134 if u.info.Version == Version { 135 return nil 136 } 137 bin, err := u.fetchAndVerifyPatch(old) 138 if err != nil { 139 switch err { 140 case ErrNoPatchAvailable: 141 log.Println("update: no patch available, falling back to full binary") 142 case ErrHashMismatch: 143 log.Println("update: hash mismatch from patched binary") 144 default: 145 log.Println("update: patching binary,", err) 146 } 147 bin, err = u.fetchAndVerifyFullBin() 148 if err != nil { 149 if err == ErrHashMismatch { 150 log.Println("update: hash mismatch from full binary") 151 } else { 152 log.Println("update: fetching full binary,", err) 153 } 154 return err 155 } 156 } 157 158 // close the old binary before installing because on windows 159 // it can't be renamed if a handle to the file is still open 160 old.Close() 161 162 err, errRecover := update.FromStream(bytes.NewBuffer(bin)) 163 if errRecover != nil { 164 return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) 165 } 166 if err != nil { 167 return err 168 } 169 log.Printf("Updated v%s -> v%s.", Version, u.info.Version) 170 return nil 171 } 172 173 func (u *Updater) fetchInfo() error { 174 r, err := fetch(u.apiURL + u.cmdName + "/current/" + plat + ".json") 175 if err != nil { 176 return err 177 } 178 defer r.Close() 179 err = json.NewDecoder(r).Decode(&u.info) 180 if err != nil { 181 return err 182 } 183 if len(u.info.Sha256) != sha256.Size { 184 return errors.New("bad cmd hash in info") 185 } 186 return nil 187 } 188 189 func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { 190 bin, err := u.fetchAndApplyPatch(old) 191 if err != nil { 192 return nil, err 193 } 194 if !verifySha(bin, u.info.Sha256) { 195 return nil, ErrHashMismatch 196 } 197 return bin, nil 198 } 199 200 func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { 201 r, err := fetch(u.diffURL + u.cmdName + "/" + Version + "/" + u.info.Version + "/" + plat) 202 if err != nil { 203 return nil, err 204 } 205 defer r.Close() 206 var buf bytes.Buffer 207 err = binarydist.Patch(old, &buf, r) 208 return buf.Bytes(), err 209 } 210 211 func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { 212 bin, err := u.fetchBin() 213 if err != nil { 214 return nil, err 215 } 216 verified := verifySha(bin, u.info.Sha256) 217 if !verified { 218 return nil, ErrHashMismatch 219 } 220 return bin, nil 221 } 222 223 func (u *Updater) fetchBin() ([]byte, error) { 224 r, err := fetch(u.binURL + u.cmdName + "/" + u.info.Version + "/" + plat + ".gz") 225 if err != nil { 226 return nil, err 227 } 228 defer r.Close() 229 buf := new(bytes.Buffer) 230 gz, err := gzip.NewReader(r) 231 if err != nil { 232 return nil, err 233 } 234 if _, err = io.Copy(buf, gz); err != nil { 235 return nil, err 236 } 237 238 return buf.Bytes(), nil 239 } 240 241 // returns a random duration in [0,n). 242 func randDuration(n time.Duration) time.Duration { 243 return time.Duration(rand.Int63n(int64(n))) 244 } 245 246 var ErrNoPatchAvailable = errors.New("no patch available") 247 248 func fetch(url string) (io.ReadCloser, error) { 249 resp, err := http.Get(url) 250 if err != nil { 251 return nil, err 252 } 253 switch resp.StatusCode { 254 case 200: 255 return resp.Body, nil 256 case 401, 403, 404: 257 return nil, ErrNoPatchAvailable 258 default: 259 return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) 260 } 261 } 262 263 func readTime(path string) time.Time { 264 p, err := ioutil.ReadFile(path) 265 if os.IsNotExist(err) { 266 return time.Time{} 267 } 268 if err != nil { 269 return time.Now().Add(1000 * time.Hour) 270 } 271 t, err := time.Parse(time.RFC3339, string(p)) 272 if err != nil { 273 return time.Now().Add(1000 * time.Hour) 274 } 275 return t 276 } 277 278 func verifySha(bin []byte, sha []byte) bool { 279 h := sha256.New() 280 h.Write(bin) 281 return bytes.Equal(h.Sum(nil), sha) 282 } 283 284 func writeTime(path string, t time.Time) bool { 285 return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil 286 }