github.com/koron/hk@v0.0.0-20150303213137-b8aeaa3ab34c/update.go (about)

     1  package main
     2  
import (
	"bytes"
	"compress/gzip"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"time"

	"github.com/heroku/hk/Godeps/_workspace/src/bitbucket.org/kardianos/osext"
	"github.com/heroku/hk/Godeps/_workspace/src/github.com/inconshreveable/go-update"
	"github.com/heroku/hk/Godeps/_workspace/src/github.com/kr/binarydist"
)
    24  
    25  var cmdUpdate = &Command{
    26  	Run:      runUpdate,
    27  	Usage:    "update",
    28  	Category: "hk",
    29  	Long: `
    30  Update downloads and installs the next version of hk.
    31  
    32  This command is unlisted, since users never have to run it directly.
    33  `,
    34  }
    35  
    36  func runUpdate(cmd *Command, args []string) {
    37  	if updater == nil {
    38  		printFatal("Dev builds don't support auto-updates")
    39  	}
    40  	if err := updater.update(); err != nil {
    41  		printFatal(err.Error())
    42  	}
    43  }
    44  
    45  const (
    46  	upcktimePath = "cktime"
    47  	plat         = runtime.GOOS + "-" + runtime.GOARCH
    48  )
    49  
    50  var ErrHashMismatch = errors.New("new file hash mismatch after patch")
    51  
    52  // Update protocol.
    53  //
    54  //   GET hk.heroku.com/hk/current/linux-amd64.json
    55  //
    56  //   200 ok
    57  //   {
    58  //       "Version": "2",
    59  //       "Sha256": "..." // base64
    60  //   }
    61  //
    62  // then
    63  //
    64  //   GET hkpatch.s3.amazonaws.com/hk/1/2/linux-amd64
    65  //
    66  //   200 ok
    67  //   [bsdiff data]
    68  //
    69  // or
    70  //
    71  //   GET hkdist.s3.amazonaws.com/hk/2/linux-amd64.gz
    72  //
    73  //   200 ok
    74  //   [gzipped executable data]
    75  type Updater struct {
    76  	apiURL  string
    77  	cmdName string
    78  	binURL  string
    79  	diffURL string
    80  	dir     string
    81  	info    struct {
    82  		Version string
    83  		Sha256  []byte
    84  	}
    85  }
    86  
    87  func (u *Updater) backgroundRun() {
    88  	os.MkdirAll(u.dir, 0777)
    89  	if u.wantUpdate() {
    90  		if err := update.SanityCheck(); err != nil {
    91  			// fail
    92  			return
    93  		}
    94  		self, err := osext.Executable()
    95  		if err != nil {
    96  			// fail update, couldn't figure out path to self
    97  			return
    98  		}
    99  		// TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports.
   100  		l := exec.Command("logger", "-thk")
   101  		c := exec.Command(self, "update")
   102  		if w, err := l.StdinPipe(); err == nil && l.Start() == nil {
   103  			c.Stdout = w
   104  			c.Stderr = w
   105  		}
   106  		c.Start()
   107  	}
   108  }
   109  
   110  func (u *Updater) wantUpdate() bool {
   111  	path := u.dir + upcktimePath
   112  	if Version == "dev" || readTime(path).After(time.Now()) {
   113  		return false
   114  	}
   115  	wait := 12*time.Hour + randDuration(8*time.Hour)
   116  	return writeTime(path, time.Now().Add(wait))
   117  }
   118  
   119  func (u *Updater) update() error {
   120  	path, err := osext.Executable()
   121  	if err != nil {
   122  		return err
   123  	}
   124  	old, err := os.Open(path)
   125  	if err != nil {
   126  		return err
   127  	}
   128  	defer old.Close()
   129  
   130  	err = u.fetchInfo()
   131  	if err != nil {
   132  		return err
   133  	}
   134  	if u.info.Version == Version {
   135  		return nil
   136  	}
   137  	bin, err := u.fetchAndVerifyPatch(old)
   138  	if err != nil {
   139  		switch err {
   140  		case ErrNoPatchAvailable:
   141  			log.Println("update: no patch available, falling back to full binary")
   142  		case ErrHashMismatch:
   143  			log.Println("update: hash mismatch from patched binary")
   144  		default:
   145  			log.Println("update: patching binary,", err)
   146  		}
   147  		bin, err = u.fetchAndVerifyFullBin()
   148  		if err != nil {
   149  			if err == ErrHashMismatch {
   150  				log.Println("update: hash mismatch from full binary")
   151  			} else {
   152  				log.Println("update: fetching full binary,", err)
   153  			}
   154  			return err
   155  		}
   156  	}
   157  
   158  	// close the old binary before installing because on windows
   159  	// it can't be renamed if a handle to the file is still open
   160  	old.Close()
   161  
   162  	err, errRecover := update.FromStream(bytes.NewBuffer(bin))
   163  	if errRecover != nil {
   164  		return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
   165  	}
   166  	if err != nil {
   167  		return err
   168  	}
   169  	log.Printf("Updated v%s -> v%s.", Version, u.info.Version)
   170  	return nil
   171  }
   172  
   173  func (u *Updater) fetchInfo() error {
   174  	r, err := fetch(u.apiURL + u.cmdName + "/current/" + plat + ".json")
   175  	if err != nil {
   176  		return err
   177  	}
   178  	defer r.Close()
   179  	err = json.NewDecoder(r).Decode(&u.info)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	if len(u.info.Sha256) != sha256.Size {
   184  		return errors.New("bad cmd hash in info")
   185  	}
   186  	return nil
   187  }
   188  
   189  func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
   190  	bin, err := u.fetchAndApplyPatch(old)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  	if !verifySha(bin, u.info.Sha256) {
   195  		return nil, ErrHashMismatch
   196  	}
   197  	return bin, nil
   198  }
   199  
   200  func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
   201  	r, err := fetch(u.diffURL + u.cmdName + "/" + Version + "/" + u.info.Version + "/" + plat)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  	defer r.Close()
   206  	var buf bytes.Buffer
   207  	err = binarydist.Patch(old, &buf, r)
   208  	return buf.Bytes(), err
   209  }
   210  
   211  func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
   212  	bin, err := u.fetchBin()
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	verified := verifySha(bin, u.info.Sha256)
   217  	if !verified {
   218  		return nil, ErrHashMismatch
   219  	}
   220  	return bin, nil
   221  }
   222  
   223  func (u *Updater) fetchBin() ([]byte, error) {
   224  	r, err := fetch(u.binURL + u.cmdName + "/" + u.info.Version + "/" + plat + ".gz")
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	defer r.Close()
   229  	buf := new(bytes.Buffer)
   230  	gz, err := gzip.NewReader(r)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	if _, err = io.Copy(buf, gz); err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	return buf.Bytes(), nil
   239  }
   240  
   241  // returns a random duration in [0,n).
   242  func randDuration(n time.Duration) time.Duration {
   243  	return time.Duration(rand.Int63n(int64(n)))
   244  }
   245  
   246  var ErrNoPatchAvailable = errors.New("no patch available")
   247  
   248  func fetch(url string) (io.ReadCloser, error) {
   249  	resp, err := http.Get(url)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	switch resp.StatusCode {
   254  	case 200:
   255  		return resp.Body, nil
   256  	case 401, 403, 404:
   257  		return nil, ErrNoPatchAvailable
   258  	default:
   259  		return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
   260  	}
   261  }
   262  
   263  func readTime(path string) time.Time {
   264  	p, err := ioutil.ReadFile(path)
   265  	if os.IsNotExist(err) {
   266  		return time.Time{}
   267  	}
   268  	if err != nil {
   269  		return time.Now().Add(1000 * time.Hour)
   270  	}
   271  	t, err := time.Parse(time.RFC3339, string(p))
   272  	if err != nil {
   273  		return time.Now().Add(1000 * time.Hour)
   274  	}
   275  	return t
   276  }
   277  
   278  func verifySha(bin []byte, sha []byte) bool {
   279  	h := sha256.New()
   280  	h.Write(bin)
   281  	return bytes.Equal(h.Sum(nil), sha)
   282  }
   283  
   284  func writeTime(path string, t time.Time) bool {
   285  	return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil
   286  }