github.com/koron/hk@v0.0.0-20150303213137-b8aeaa3ab34c/hkdist/web.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"database/sql"
     7  	"encoding/json"
     8  	"io"
     9  	"log"
    10  	"net/http"
    11  	"os"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/heroku/hk/Godeps/_workspace/src/github.com/bmizerany/pq"
    16  	"github.com/heroku/hk/Godeps/_workspace/src/github.com/gorilla/mux"
    17  	"github.com/heroku/hk/Godeps/_workspace/src/github.com/kr/secureheader"
    18  )
    19  
    20  const (
    21  	pgUniqueViolation = "23505"
    22  )
    23  
    24  var db *sql.DB
    25  
    26  // Examples:
    27  //
    28  //   PUT /hk-1-linux-386.json
    29  //   PUT /hk-linux-386.json
    30  //
    31  //   GET /hk-current-linux-386.json
    32  //   GET /hk-1-linux-386.json
    33  //   GET /hk.gz
    34  func web(args []string) {
    35  	mustHaveEnv("DATABASE_URL")
    36  	initwebdb()
    37  	r := mux.NewRouter()
    38  	r.HandleFunc("/{cmd}.gz", http.HandlerFunc(initial)).Methods("GET", "HEAD")
    39  	r.HandleFunc("/{cmd}/current/{plat}.json", http.HandlerFunc(curInfo)).Methods("GET", "HEAD")
    40  	r.HandleFunc("/{cmd}/{ver}/{plat}.json", http.HandlerFunc(getHash)).Methods("GET", "HEAD")
    41  	r.HandleFunc("/release.json", http.HandlerFunc(listReleases)).Methods("GET", "HEAD")
    42  	r.Path("/{cmd}/current/{plat}.json").Methods("PUT").Handler(authenticate{herokaiOnly{http.HandlerFunc(setCur)}})
    43  	r.Path("/{cmd}/{ver}/{plat}.json").Methods("PUT").Handler(authenticate{herokaiOnly{http.HandlerFunc(putVer)}})
    44  	r.PathPrefix("/").Methods("GET", "HEAD").Handler(http.FileServer(http.Dir("hkdist/public")))
    45  	http.Handle("/", r)
    46  	secureheader.DefaultConfig.PermitClearLoopback = true
    47  	err := http.ListenAndServe(":"+os.Getenv("PORT"), secureheader.DefaultConfig)
    48  	if err != nil {
    49  		log.Fatalf(`{"func":"ListenAndServe", "error":%q}`, err)
    50  	}
    51  }
    52  
    53  func setCur(w http.ResponseWriter, r *http.Request) {
    54  	defer r.Body.Close()
    55  	q := mux.Vars(r)
    56  	plat := q["plat"]
    57  	cmd := q["cmd"]
    58  	if strings.IndexFunc(plat, badIdentRune) >= 0 ||
    59  		strings.IndexFunc(cmd, badIdentRune) >= 0 {
    60  		http.Error(w, "bad character in path", 400)
    61  		return
    62  	}
    63  
    64  	var info struct{ Version string }
    65  	if !readReqJSON(w, r, 1000, &info) {
    66  		return
    67  	}
    68  	_, err := db.Exec(`
    69  		update cur set curver=$1 where plat=$2 and cmd=$3
    70  	`, info.Version, plat, cmd)
    71  	if err != nil {
    72  		log.Println(err)
    73  		http.Error(w, "internal error", 500)
    74  		return
    75  	}
    76  	_, err = db.Exec(`
    77  		insert into cur (plat, cmd, curver)
    78  		select $1, $2, $3
    79  		where not exists (select 1 from cur where plat=$1 and cmd=$2)
    80  	`, plat, cmd, info.Version)
    81  	if err != nil {
    82  		log.Println(err)
    83  		http.Error(w, "internal error", 500)
    84  		return
    85  	}
    86  	if _, err = db.Exec(`update mod set t=now()`); err != nil {
    87  		log.Println(err)
    88  		http.Error(w, "internal error", 500)
    89  		return
    90  	}
    91  	io.WriteString(w, "ok\n")
    92  }
    93  
    94  func scan(w http.ResponseWriter, r *http.Request, q *sql.Row, v ...interface{}) bool {
    95  	switch err := q.Scan(v...); err {
    96  	case nil:
    97  	case sql.ErrNoRows:
    98  		http.NotFound(w, r)
    99  		return false
   100  	default:
   101  		log.Println(err)
   102  		w.WriteHeader(500)
   103  		return false
   104  	}
   105  	return true
   106  }
   107  
   108  func lookupCurRel(w http.ResponseWriter, r *http.Request, plat, cmd string) (v release, ok bool) {
   109  	v.Cmd = cmd
   110  	v.Plat = plat
   111  	const s = `select c.curver, r.sha256 from cur c, release r
   112  				where c.plat=$1 and c.cmd=$2
   113  				and c.plat = r.plat and c.cmd = r.cmd and c.curver = r.ver`
   114  	ok = scan(w, r, db.QueryRow(s, plat, cmd), &v.Ver, &v.Sha256)
   115  	return
   116  }
   117  
   118  func initial(w http.ResponseWriter, r *http.Request) {
   119  	cmd := mux.Vars(r)["cmd"]
   120  	plat := guessPlat(r.UserAgent())
   121  	if rel, ok := lookupCurRel(w, r, plat, cmd); ok {
   122  		url := s3DistURL + rel.Gzname()
   123  		http.Redirect(w, r, url, http.StatusTemporaryRedirect)
   124  	}
   125  }
   126  
   127  func curInfo(w http.ResponseWriter, r *http.Request) {
   128  	q := mux.Vars(r)
   129  	if rel, ok := lookupCurRel(w, r, q["plat"], q["cmd"]); ok {
   130  		logErr(json.NewEncoder(w).Encode(rel))
   131  	}
   132  }
   133  
   134  func getHash(w http.ResponseWriter, r *http.Request) {
   135  	q := mux.Vars(r)
   136  	var info jsonsha
   137  	const s = `select sha256 from release where plat=$1 and cmd=$2 and ver=$3`
   138  	if scan(w, r, db.QueryRow(s, q["plat"], q["cmd"], q["ver"]), &info.Sha256) {
   139  		logErr(json.NewEncoder(w).Encode(info))
   140  	}
   141  }
   142  
   143  func listReleases(w http.ResponseWriter, r *http.Request) {
   144  	rels := make([]release, 0)
   145  	rows, err := db.Query(`select plat, cmd, ver, sha256 from release`)
   146  	if err != nil {
   147  		log.Println(err)
   148  		http.Error(w, "internal error", 500)
   149  		return
   150  	}
   151  	for rows.Next() {
   152  		var rel release
   153  		err := rows.Scan(&rel.Plat, &rel.Cmd, &rel.Ver, &rel.Sha256)
   154  		if err != nil {
   155  			log.Println(err)
   156  		} else {
   157  			rels = append(rels, rel)
   158  		}
   159  	}
   160  	if err := rows.Err(); err != nil {
   161  		log.Println(err)
   162  		http.Error(w, "internal error", 500)
   163  		return
   164  	}
   165  	b := new(bytes.Buffer)
   166  	if err = json.NewEncoder(b).Encode(rels); err != nil {
   167  		log.Println(err)
   168  		http.Error(w, "internal error", 500)
   169  		return
   170  	}
   171  	var mod time.Time
   172  	db.QueryRow(`select t from mod`).Scan(&mod)
   173  	http.ServeContent(w, r, "", mod, bytes.NewReader(b.Bytes()))
   174  }
   175  
   176  func logErr(err error) error {
   177  	if err != nil {
   178  		log.Println(err)
   179  	}
   180  	return err
   181  }
   182  
   183  func isDarwin(ua string) bool {
   184  	return strings.Contains(ua, "mac os x") || strings.Contains(ua, "darwin")
   185  }
   186  
   187  func guessArch(ua string) string {
   188  	if strings.Contains(ua, "x86_64") || strings.Contains(ua, "amd64") || isDarwin(ua) {
   189  		return "amd64"
   190  	}
   191  	return "386"
   192  }
   193  
   194  func guessOS(ua string) string {
   195  	if isDarwin(ua) {
   196  		return "darwin"
   197  	}
   198  	if strings.Contains(ua, "windows") {
   199  		return "windows"
   200  	}
   201  	return "linux"
   202  }
   203  
   204  func guessPlat(ua string) string {
   205  	ua = strings.ToLower(ua)
   206  	return guessOS(ua) + "-" + guessArch(ua)
   207  }
   208  
   209  func putVer(w http.ResponseWriter, r *http.Request) {
   210  	defer r.Body.Close()
   211  	q := mux.Vars(r)
   212  	plat := q["plat"]
   213  	cmd := q["cmd"]
   214  	ver := q["ver"]
   215  	if strings.IndexFunc(plat, badIdentRune) >= 0 ||
   216  		strings.IndexFunc(cmd, badIdentRune) >= 0 ||
   217  		strings.IndexFunc(ver, badVersionRune) >= 0 {
   218  		http.Error(w, "bad character in path", 400)
   219  		return
   220  	}
   221  
   222  	var info jsonsha
   223  	if !readReqJSON(w, r, 1000, &info) {
   224  		return
   225  	}
   226  	if len(info.Sha256) != sha256.Size {
   227  		log.Printf("bad hash length %d != %d", len(info.Sha256), sha256.Size)
   228  		http.Error(w, "unprocessable entity", 422)
   229  		return
   230  	}
   231  
   232  	_, err := db.Exec(`
   233  		insert into release (plat, cmd, ver, sha256)
   234  		values ($1, $2, $3, $4)
   235  	`, plat, cmd, ver, info.Sha256)
   236  	if pe, ok := err.(pq.PGError); ok && pe.Get('C') == pgUniqueViolation {
   237  		http.Error(w, "conflict", http.StatusConflict)
   238  		return
   239  	} else if err != nil {
   240  		log.Println(err)
   241  		http.Error(w, "internal error", 500)
   242  		return
   243  	}
   244  	if _, err = db.Exec(`update mod set t=now()`); err != nil {
   245  		log.Println(err)
   246  		http.Error(w, "internal error", 500)
   247  		return
   248  	}
   249  	w.WriteHeader(http.StatusCreated)
   250  	io.WriteString(w, "created\n")
   251  }
   252  
   253  func readReqJSON(w http.ResponseWriter, r *http.Request, n int64, v interface{}) bool {
   254  	err := json.NewDecoder(http.MaxBytesReader(w, r.Body, n)).Decode(v)
   255  	if err != nil {
   256  		http.Error(w, "unprocessable entity", 422)
   257  	}
   258  	return err == nil
   259  }
   260  
   261  type authenticate struct {
   262  	http.Handler
   263  }
   264  
   265  func (x authenticate) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   266  	hr, _ := http.NewRequest("GET", "https://api.heroku.com/account", nil)
   267  	hr.Header.Set("Accept", "application/vnd.heroku+json; version=3")
   268  	hr.Header.Set("Authorization", r.Header.Get("Authorization"))
   269  	res, err := http.DefaultClient.Do(hr)
   270  	if err != nil {
   271  		log.Println(err)
   272  		http.Error(w, "internal error", 500)
   273  		return
   274  	}
   275  	if res.StatusCode == 401 {
   276  		http.Error(w, "unauthorized", 401)
   277  		return
   278  	}
   279  	if res.StatusCode != 200 {
   280  		log.Println("unexpected status from heroku api:", res.StatusCode)
   281  		http.Error(w, "internal error", 500)
   282  		return
   283  	}
   284  
   285  	var info struct {
   286  		Email string
   287  	}
   288  	err = json.NewDecoder(res.Body).Decode(&info)
   289  	res.Body.Close()
   290  	if err != nil {
   291  		log.Println(err)
   292  		http.Error(w, "internal error", 500)
   293  		return
   294  	}
   295  
   296  	r.Header.Set(":email", info.Email)
   297  	x.Handler.ServeHTTP(w, r)
   298  }
   299  
   300  type herokaiOnly struct {
   301  	http.Handler
   302  }
   303  
   304  func (x herokaiOnly) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   305  	if !strings.HasSuffix(r.Header.Get(":email"), "@heroku.com") {
   306  		http.Error(w, "unauthorized", 401)
   307  		return
   308  	}
   309  	x.Handler.ServeHTTP(w, r)
   310  }
   311  
   312  func mustExec(q string) {
   313  	if _, err := db.Exec(q); err != nil {
   314  		log.Fatal(err)
   315  	}
   316  }
   317  
   318  func initwebdb() {
   319  	connstr, err := pq.ParseURL(os.Getenv("DATABASE_URL"))
   320  	if err != nil {
   321  		log.Fatal("pq.ParseURL", err)
   322  	}
   323  	db, err = sql.Open("postgres", connstr+" sslmode=disable")
   324  	if err != nil {
   325  		log.Fatal("sql.Open", err)
   326  	}
   327  	mustExec(`SET bytea_output = 'hex'`) // work around https://github.com/bmizerany/pq/issues/76
   328  	mustExec(`create table if not exists release (
   329  		plat text not null,
   330  		cmd text not null,
   331  		ver text not null,
   332  		sha256 bytea not null,
   333  		primary key (plat, cmd, ver)
   334  	)`)
   335  	mustExec(`create table if not exists cur (
   336  		plat text not null,
   337  		cmd text not null,
   338  		curver text not null,
   339  		foreign key (plat, cmd, curver) references release (plat, cmd, ver),
   340  		primary key (plat, cmd)
   341  	)`)
   342  	mustExec(`create table if not exists mod (
   343  		t timestamptz not null
   344  	)`)
   345  	mustExec(`insert into mod (t)
   346  		select now()
   347  		where not exists (select 1 from mod)
   348  	`)
   349  }
   350  
   351  func badIdentRune(r rune) bool {
   352  	return !(r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-')
   353  }
   354  
   355  func badVersionRune(r rune) bool {
   356  	return !(r >= '0' && r <= '9' || r == '.')
   357  }