github.com/koron/hk@v0.0.0-20150303213137-b8aeaa3ab34c/hkdist/web.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "database/sql" 7 "encoding/json" 8 "io" 9 "log" 10 "net/http" 11 "os" 12 "strings" 13 "time" 14 15 "github.com/heroku/hk/Godeps/_workspace/src/github.com/bmizerany/pq" 16 "github.com/heroku/hk/Godeps/_workspace/src/github.com/gorilla/mux" 17 "github.com/heroku/hk/Godeps/_workspace/src/github.com/kr/secureheader" 18 ) 19 20 const ( 21 pgUniqueViolation = "23505" 22 ) 23 24 var db *sql.DB 25 26 // Examples: 27 // 28 // PUT /hk-1-linux-386.json 29 // PUT /hk-linux-386.json 30 // 31 // GET /hk-current-linux-386.json 32 // GET /hk-1-linux-386.json 33 // GET /hk.gz 34 func web(args []string) { 35 mustHaveEnv("DATABASE_URL") 36 initwebdb() 37 r := mux.NewRouter() 38 r.HandleFunc("/{cmd}.gz", http.HandlerFunc(initial)).Methods("GET", "HEAD") 39 r.HandleFunc("/{cmd}/current/{plat}.json", http.HandlerFunc(curInfo)).Methods("GET", "HEAD") 40 r.HandleFunc("/{cmd}/{ver}/{plat}.json", http.HandlerFunc(getHash)).Methods("GET", "HEAD") 41 r.HandleFunc("/release.json", http.HandlerFunc(listReleases)).Methods("GET", "HEAD") 42 r.Path("/{cmd}/current/{plat}.json").Methods("PUT").Handler(authenticate{herokaiOnly{http.HandlerFunc(setCur)}}) 43 r.Path("/{cmd}/{ver}/{plat}.json").Methods("PUT").Handler(authenticate{herokaiOnly{http.HandlerFunc(putVer)}}) 44 r.PathPrefix("/").Methods("GET", "HEAD").Handler(http.FileServer(http.Dir("hkdist/public"))) 45 http.Handle("/", r) 46 secureheader.DefaultConfig.PermitClearLoopback = true 47 err := http.ListenAndServe(":"+os.Getenv("PORT"), secureheader.DefaultConfig) 48 if err != nil { 49 log.Fatalf(`{"func":"ListenAndServe", "error":%q}`, err) 50 } 51 } 52 53 func setCur(w http.ResponseWriter, r *http.Request) { 54 defer r.Body.Close() 55 q := mux.Vars(r) 56 plat := q["plat"] 57 cmd := q["cmd"] 58 if strings.IndexFunc(plat, badIdentRune) >= 0 || 59 strings.IndexFunc(cmd, badIdentRune) >= 0 { 60 http.Error(w, "bad 
character in path", 400) 61 return 62 } 63 64 var info struct{ Version string } 65 if !readReqJSON(w, r, 1000, &info) { 66 return 67 } 68 _, err := db.Exec(` 69 update cur set curver=$1 where plat=$2 and cmd=$3 70 `, info.Version, plat, cmd) 71 if err != nil { 72 log.Println(err) 73 http.Error(w, "internal error", 500) 74 return 75 } 76 _, err = db.Exec(` 77 insert into cur (plat, cmd, curver) 78 select $1, $2, $3 79 where not exists (select 1 from cur where plat=$1 and cmd=$2) 80 `, plat, cmd, info.Version) 81 if err != nil { 82 log.Println(err) 83 http.Error(w, "internal error", 500) 84 return 85 } 86 if _, err = db.Exec(`update mod set t=now()`); err != nil { 87 log.Println(err) 88 http.Error(w, "internal error", 500) 89 return 90 } 91 io.WriteString(w, "ok\n") 92 } 93 94 func scan(w http.ResponseWriter, r *http.Request, q *sql.Row, v ...interface{}) bool { 95 switch err := q.Scan(v...); err { 96 case nil: 97 case sql.ErrNoRows: 98 http.NotFound(w, r) 99 return false 100 default: 101 log.Println(err) 102 w.WriteHeader(500) 103 return false 104 } 105 return true 106 } 107 108 func lookupCurRel(w http.ResponseWriter, r *http.Request, plat, cmd string) (v release, ok bool) { 109 v.Cmd = cmd 110 v.Plat = plat 111 const s = `select c.curver, r.sha256 from cur c, release r 112 where c.plat=$1 and c.cmd=$2 113 and c.plat = r.plat and c.cmd = r.cmd and c.curver = r.ver` 114 ok = scan(w, r, db.QueryRow(s, plat, cmd), &v.Ver, &v.Sha256) 115 return 116 } 117 118 func initial(w http.ResponseWriter, r *http.Request) { 119 cmd := mux.Vars(r)["cmd"] 120 plat := guessPlat(r.UserAgent()) 121 if rel, ok := lookupCurRel(w, r, plat, cmd); ok { 122 url := s3DistURL + rel.Gzname() 123 http.Redirect(w, r, url, http.StatusTemporaryRedirect) 124 } 125 } 126 127 func curInfo(w http.ResponseWriter, r *http.Request) { 128 q := mux.Vars(r) 129 if rel, ok := lookupCurRel(w, r, q["plat"], q["cmd"]); ok { 130 logErr(json.NewEncoder(w).Encode(rel)) 131 } 132 } 133 134 func getHash(w 
http.ResponseWriter, r *http.Request) { 135 q := mux.Vars(r) 136 var info jsonsha 137 const s = `select sha256 from release where plat=$1 and cmd=$2 and ver=$3` 138 if scan(w, r, db.QueryRow(s, q["plat"], q["cmd"], q["ver"]), &info.Sha256) { 139 logErr(json.NewEncoder(w).Encode(info)) 140 } 141 } 142 143 func listReleases(w http.ResponseWriter, r *http.Request) { 144 rels := make([]release, 0) 145 rows, err := db.Query(`select plat, cmd, ver, sha256 from release`) 146 if err != nil { 147 log.Println(err) 148 http.Error(w, "internal error", 500) 149 return 150 } 151 for rows.Next() { 152 var rel release 153 err := rows.Scan(&rel.Plat, &rel.Cmd, &rel.Ver, &rel.Sha256) 154 if err != nil { 155 log.Println(err) 156 } else { 157 rels = append(rels, rel) 158 } 159 } 160 if err := rows.Err(); err != nil { 161 log.Println(err) 162 http.Error(w, "internal error", 500) 163 return 164 } 165 b := new(bytes.Buffer) 166 if err = json.NewEncoder(b).Encode(rels); err != nil { 167 log.Println(err) 168 http.Error(w, "internal error", 500) 169 return 170 } 171 var mod time.Time 172 db.QueryRow(`select t from mod`).Scan(&mod) 173 http.ServeContent(w, r, "", mod, bytes.NewReader(b.Bytes())) 174 } 175 176 func logErr(err error) error { 177 if err != nil { 178 log.Println(err) 179 } 180 return err 181 } 182 183 func isDarwin(ua string) bool { 184 return strings.Contains(ua, "mac os x") || strings.Contains(ua, "darwin") 185 } 186 187 func guessArch(ua string) string { 188 if strings.Contains(ua, "x86_64") || strings.Contains(ua, "amd64") || isDarwin(ua) { 189 return "amd64" 190 } 191 return "386" 192 } 193 194 func guessOS(ua string) string { 195 if isDarwin(ua) { 196 return "darwin" 197 } 198 if strings.Contains(ua, "windows") { 199 return "windows" 200 } 201 return "linux" 202 } 203 204 func guessPlat(ua string) string { 205 ua = strings.ToLower(ua) 206 return guessOS(ua) + "-" + guessArch(ua) 207 } 208 209 func putVer(w http.ResponseWriter, r *http.Request) { 210 defer r.Body.Close() 
211 q := mux.Vars(r) 212 plat := q["plat"] 213 cmd := q["cmd"] 214 ver := q["ver"] 215 if strings.IndexFunc(plat, badIdentRune) >= 0 || 216 strings.IndexFunc(cmd, badIdentRune) >= 0 || 217 strings.IndexFunc(ver, badVersionRune) >= 0 { 218 http.Error(w, "bad character in path", 400) 219 return 220 } 221 222 var info jsonsha 223 if !readReqJSON(w, r, 1000, &info) { 224 return 225 } 226 if len(info.Sha256) != sha256.Size { 227 log.Printf("bad hash length %d != %d", len(info.Sha256), sha256.Size) 228 http.Error(w, "unprocessable entity", 422) 229 return 230 } 231 232 _, err := db.Exec(` 233 insert into release (plat, cmd, ver, sha256) 234 values ($1, $2, $3, $4) 235 `, plat, cmd, ver, info.Sha256) 236 if pe, ok := err.(pq.PGError); ok && pe.Get('C') == pgUniqueViolation { 237 http.Error(w, "conflict", http.StatusConflict) 238 return 239 } else if err != nil { 240 log.Println(err) 241 http.Error(w, "internal error", 500) 242 return 243 } 244 if _, err = db.Exec(`update mod set t=now()`); err != nil { 245 log.Println(err) 246 http.Error(w, "internal error", 500) 247 return 248 } 249 w.WriteHeader(http.StatusCreated) 250 io.WriteString(w, "created\n") 251 } 252 253 func readReqJSON(w http.ResponseWriter, r *http.Request, n int64, v interface{}) bool { 254 err := json.NewDecoder(http.MaxBytesReader(w, r.Body, n)).Decode(v) 255 if err != nil { 256 http.Error(w, "unprocessable entity", 422) 257 } 258 return err == nil 259 } 260 261 type authenticate struct { 262 http.Handler 263 } 264 265 func (x authenticate) ServeHTTP(w http.ResponseWriter, r *http.Request) { 266 hr, _ := http.NewRequest("GET", "https://api.heroku.com/account", nil) 267 hr.Header.Set("Accept", "application/vnd.heroku+json; version=3") 268 hr.Header.Set("Authorization", r.Header.Get("Authorization")) 269 res, err := http.DefaultClient.Do(hr) 270 if err != nil { 271 log.Println(err) 272 http.Error(w, "internal error", 500) 273 return 274 } 275 if res.StatusCode == 401 { 276 http.Error(w, "unauthorized", 401) 
277 return 278 } 279 if res.StatusCode != 200 { 280 log.Println("unexpected status from heroku api:", res.StatusCode) 281 http.Error(w, "internal error", 500) 282 return 283 } 284 285 var info struct { 286 Email string 287 } 288 err = json.NewDecoder(res.Body).Decode(&info) 289 res.Body.Close() 290 if err != nil { 291 log.Println(err) 292 http.Error(w, "internal error", 500) 293 return 294 } 295 296 r.Header.Set(":email", info.Email) 297 x.Handler.ServeHTTP(w, r) 298 } 299 300 type herokaiOnly struct { 301 http.Handler 302 } 303 304 func (x herokaiOnly) ServeHTTP(w http.ResponseWriter, r *http.Request) { 305 if !strings.HasSuffix(r.Header.Get(":email"), "@heroku.com") { 306 http.Error(w, "unauthorized", 401) 307 return 308 } 309 x.Handler.ServeHTTP(w, r) 310 } 311 312 func mustExec(q string) { 313 if _, err := db.Exec(q); err != nil { 314 log.Fatal(err) 315 } 316 } 317 318 func initwebdb() { 319 connstr, err := pq.ParseURL(os.Getenv("DATABASE_URL")) 320 if err != nil { 321 log.Fatal("pq.ParseURL", err) 322 } 323 db, err = sql.Open("postgres", connstr+" sslmode=disable") 324 if err != nil { 325 log.Fatal("sql.Open", err) 326 } 327 mustExec(`SET bytea_output = 'hex'`) // work around https://github.com/bmizerany/pq/issues/76 328 mustExec(`create table if not exists release ( 329 plat text not null, 330 cmd text not null, 331 ver text not null, 332 sha256 bytea not null, 333 primary key (plat, cmd, ver) 334 )`) 335 mustExec(`create table if not exists cur ( 336 plat text not null, 337 cmd text not null, 338 curver text not null, 339 foreign key (plat, cmd, curver) references release (plat, cmd, ver), 340 primary key (plat, cmd) 341 )`) 342 mustExec(`create table if not exists mod ( 343 t timestamptz not null 344 )`) 345 mustExec(`insert into mod (t) 346 select now() 347 where not exists (select 1 from mod) 348 `) 349 } 350 351 func badIdentRune(r rune) bool { 352 return !(r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-') 353 } 354 355 
// badVersionRune reports whether r is not allowed in a release version
// path segment. Only the ASCII digits '0' through '9' and '.' are permitted.
func badVersionRune(r rune) bool {
	isDigit := '0' <= r && r <= '9'
	return !isDigit && r != '.'
}