github.com/llimllib/devd@v0.0.0-20230426145215-4d29fc25f909/server.go (about) 1 package devd 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "html/template" 7 "net" 8 "net/http" 9 "os" 10 "os/signal" 11 "regexp" 12 "strings" 13 "syscall" 14 "time" 15 16 "golang.org/x/net/context" 17 18 rice "github.com/GeertJohan/go.rice" 19 "github.com/goji/httpauth" 20 21 "github.com/cortesi/termlog" 22 "github.com/llimllib/devd/httpctx" 23 "github.com/llimllib/devd/inject" 24 "github.com/llimllib/devd/livereload" 25 "github.com/llimllib/devd/ricetemp" 26 "github.com/llimllib/devd/slowdown" 27 "github.com/llimllib/devd/timer" 28 ) 29 30 const ( 31 // Version is the current version of devd 32 Version = "0.9" 33 portLow = 8000 34 portHigh = 10000 35 ) 36 37 func pickPort(addr string, low int, high int, tls bool) (net.Listener, error) { 38 firstTry := 80 39 if tls { 40 firstTry = 443 41 } 42 hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, firstTry)) 43 if err == nil { 44 return hl, nil 45 } 46 for i := low; i < high; i++ { 47 hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, i)) 48 if err == nil { 49 return hl, nil 50 } 51 } 52 return nil, fmt.Errorf("could not find open port") 53 } 54 55 func getTLSConfig(path string) (t *tls.Config, err error) { 56 config := &tls.Config{} 57 if config.NextProtos == nil { 58 config.NextProtos = []string{"http/1.1"} 59 } 60 config.Certificates = make([]tls.Certificate, 1) 61 config.Certificates[0], err = tls.LoadX509KeyPair(path, path) 62 if err != nil { 63 return nil, err 64 } 65 return config, nil 66 } 67 68 // This filthy hack works in conjunction with hostPortStrip to restore the 69 // original request host after mux match. 70 func revertOriginalHost(r *http.Request) { 71 original := r.Header.Get("_devd_original_host") 72 if original != "" { 73 r.Host = original 74 r.Header.Del("_devd_original_host") 75 } 76 } 77 78 // We can remove the mangling once this is fixed: 79 // https://github.com/golang/go/issues/10463 80 func hostPortStrip(next http.Handler) http.Handler { 81 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 82 host, _, err := net.SplitHostPort(r.Host) 83 if err == nil { 84 original := r.Host 85 r.Host = host 86 r.Header.Set("_devd_original_host", original) 87 } 88 next.ServeHTTP(w, r) 89 }) 90 } 91 92 func matchStringAny(regexps []*regexp.Regexp, s string) bool { 93 for _, r := range regexps { 94 if r.MatchString(s) { 95 return true 96 } 97 } 98 return false 99 } 100 101 func formatURL(tls bool, httpIP string, port int) string { 102 proto := "http" 103 if tls { 104 proto = "https" 105 } 106 host := httpIP 107 if httpIP == "0.0.0.0" || httpIP == "127.0.0.1" { 108 host = "devd.io" 109 } 110 if port == 443 && tls { 111 return fmt.Sprintf("https://%s", host) 112 } 113 if port == 80 && !tls { 114 return fmt.Sprintf("http://%s", host) 115 } 116 return fmt.Sprintf("%s://%s:%d", proto, host, port) 117 } 118 119 // Credentials is a simple username/password pair 120 type Credentials struct { 121 username string 122 password string 123 } 124 125 // CredentialsFromSpec creates a set of credentials from a spec 126 func CredentialsFromSpec(spec string) (*Credentials, error) { 127 parts := strings.SplitN(spec, ":", 2) 128 if len(parts) != 2 || parts[0] == "" || parts[1] == "" { 129 return nil, fmt.Errorf("invalid credential spec: %s", spec) 130 } 131 return &Credentials{parts[0], parts[1]}, nil 132 } 133 134 // Devd represents the devd server options 135 type Devd struct { 136 Routes RouteCollection 137 138 // Shaping 139 Latency int 140 DownKbps uint 141 UpKbps uint 142 ServingScheme string 143 144 // Add headers 145 AddHeaders *http.Header 146 147 // Livereload and watch static routes 148 LivereloadRoutes bool 149 // Livereload, but don't watch static routes 150 Livereload bool 151 WatchPaths []string 152 Excludes []string 153 154 // Add Access-Control-Allow-Origin header 155 Cors bool 156 157 // Logging 158 IgnoreLogs []*regexp.Regexp 159 160 // Password protection 161 Credentials *Credentials 162 163 lrserver *livereload.Server 164 } 165 166 // WrapHandler wraps an httpctx.Handler in the paraphernalia needed by devd for 167 // logging, latency, and so forth. 168 func (dd *Devd) WrapHandler(log termlog.TermLog, next httpctx.Handler) http.Handler { 169 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 r.URL.Scheme = dd.ServingScheme 171 revertOriginalHost(r) 172 timr := timer.Timer{} 173 sublog := log.Group() 174 defer func() { 175 timing := termlog.DefaultPalette.Timestamp.SprintFunc()("timing: ") 176 sublog.SayAs("timer", timing+timr.String()) 177 sublog.Done() 178 }() 179 if matchStringAny(dd.IgnoreLogs, fmt.Sprintf("%s%s", r.URL.Host, r.RequestURI)) { 180 sublog.Quiet() 181 } 182 timr.RequestHeaders() 183 time.Sleep(time.Millisecond * time.Duration(dd.Latency)) 184 185 dpath := r.RequestURI 186 if !strings.HasPrefix(dpath, "/") { 187 dpath = "/" + dpath 188 } 189 sublog.Say("%s %s", r.Method, dpath) 190 LogHeader(sublog, r.Header) 191 ctx := timr.NewContext(context.Background()) 192 ctx = termlog.NewContext(ctx, sublog) 193 if dd.AddHeaders != nil { 194 for h, vals := range *dd.AddHeaders { 195 for _, v := range vals { 196 w.Header().Set(h, v) 197 } 198 } 199 } 200 if dd.Cors { 201 origin := r.Header.Get("Origin") 202 if origin == "" { 203 origin = "*" 204 } 205 w.Header().Set("Access-Control-Allow-Origin", origin) 206 requestHeaders := r.Header.Get("Access-Control-Request-Headers") 207 if requestHeaders != "" { 208 w.Header().Set("Access-Control-Allow-Headers", requestHeaders) 209 } 210 requestMethod := r.Header.Get("Access-Control-Request-Method") 211 if requestMethod != "" { 212 w.Header().Set("Access-Control-Allow-Methods", requestMethod) 213 } 214 215 // required for SharedArrayBuffer usage 216 // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer 217 w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") 218 w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp") 219 } 220 flusher, _ := w.(http.Flusher) 221 next.ServeHTTPContext( 222 ctx, 223 &ResponseLogWriter{Log: sublog, Resp: w, Flusher: flusher, Timer: &timr}, 224 r, 225 ) 226 }) 227 return h 228 } 229 230 // HasLivereload tells us if livereload is enabled 231 func (dd *Devd) HasLivereload() bool { 232 if dd.Livereload || dd.LivereloadRoutes || len(dd.WatchPaths) > 0 { 233 return true 234 } 235 return false 236 } 237 238 // AddRoutes adds route specifications to the server 239 func (dd *Devd) AddRoutes(specs []string, notfound []string) error { 240 dd.Routes = make(RouteCollection) 241 for _, s := range specs { 242 err := dd.Routes.Add(s, notfound) 243 if err != nil { 244 return fmt.Errorf("invalid route specification: %s", err) 245 } 246 } 247 return nil 248 } 249 250 // AddIgnores adds log ignore patterns to the server 251 func (dd *Devd) AddIgnores(specs []string) error { 252 dd.IgnoreLogs = make([]*regexp.Regexp, 0) 253 for _, expr := range specs { 254 v, err := regexp.Compile(expr) 255 if err != nil { 256 return fmt.Errorf("%s", err) 257 } 258 dd.IgnoreLogs = append(dd.IgnoreLogs, v) 259 } 260 return nil 261 } 262 263 // HandleNotFound handles pages not found. In particular, this handler is used 264 // when we have no matching route for a request. This also means it's not 265 // useful to inject the livereload paraphernalia here. 266 func HandleNotFound(templates *template.Template) httpctx.Handler { 267 return httpctx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, _ *http.Request) { 268 w.WriteHeader(http.StatusNotFound) 269 err := templates.Lookup("404.html").Execute(w, nil) 270 if err != nil { 271 logger := termlog.FromContext(ctx) 272 logger.Shout("Could not execute template: %s", err) 273 } 274 }) 275 } 276 277 // Router constructs the main Devd router that serves all requests 278 func (dd *Devd) Router(logger termlog.TermLog, templates *template.Template) (http.Handler, error) { 279 mux := http.NewServeMux() 280 hasGlobal := false 281 282 ci := inject.CopyInject{} 283 if dd.HasLivereload() { 284 ci = livereload.Injector 285 } 286 287 for match, route := range dd.Routes { 288 if match == "/" { 289 hasGlobal = true 290 } 291 handler := dd.WrapHandler( 292 logger, 293 route.Endpoint.Handler(route.Path, templates, ci), 294 ) 295 mux.Handle(match, handler) 296 } 297 if dd.HasLivereload() { 298 lr := livereload.NewServer("livereload", logger) 299 mux.Handle(livereload.EndpointPath, lr) 300 mux.Handle(livereload.ScriptPath, http.HandlerFunc(lr.ServeScript)) 301 seen := make(map[string]bool) 302 for _, route := range dd.Routes { 303 if _, ok := seen[route.Host]; route.Host != "" && !ok { 304 mux.Handle(route.Host+livereload.EndpointPath, lr) 305 mux.Handle( 306 route.Host+livereload.ScriptPath, 307 http.HandlerFunc(lr.ServeScript), 308 ) 309 seen[route.Host] = true 310 } 311 } 312 if dd.LivereloadRoutes { 313 err := WatchRoutes(dd.Routes, lr, dd.Excludes, logger) 314 if err != nil { 315 return nil, fmt.Errorf("could not watch routes for livereload: %s", err) 316 } 317 } 318 if len(dd.WatchPaths) > 0 { 319 err := WatchPaths(dd.WatchPaths, dd.Excludes, lr, logger) 320 if err != nil { 321 return nil, fmt.Errorf("could not watch path for livereload: %s", err) 322 } 323 } 324 dd.lrserver = lr 325 } 326 if !hasGlobal { 327 mux.Handle( 328 "/", 329 dd.WrapHandler(logger, HandleNotFound(templates)), 330 ) 331 } 332 h := http.Handler(mux) 333 if dd.Credentials != nil { 334 h = httpauth.SimpleBasicAuth( 335 dd.Credentials.username, dd.Credentials.password, 336 )(h) 337 } 338 return hostPortStrip(h), nil 339 } 340 341 // Serve starts the devd server. The callback is called with the serving URL 342 // just before service starts. 343 func (dd *Devd) Serve(address string, port int, certFile string, logger termlog.TermLog, callback func(string)) error { 344 templates, err := ricetemp.MakeTemplates(rice.MustFindBox("templates")) 345 if err != nil { 346 return fmt.Errorf("error loading templates: %s", err) 347 } 348 mux, err := dd.Router(logger, templates) 349 if err != nil { 350 return err 351 } 352 var tlsConfig *tls.Config 353 var tlsEnabled bool 354 if certFile != "" { 355 tlsConfig, err = getTLSConfig(certFile) 356 if err != nil { 357 return fmt.Errorf("could not load certs: %s", err) 358 } 359 tlsEnabled = true 360 } 361 362 var hl net.Listener 363 if port > 0 { 364 hl, err = net.Listen("tcp", fmt.Sprintf("%v:%d", address, port)) 365 } else { 366 hl, err = pickPort(address, portLow, portHigh, tlsEnabled) 367 } 368 if err != nil { 369 return err 370 } 371 372 if tlsConfig != nil { 373 hl = tls.NewListener(hl, tlsConfig) 374 } 375 376 hl = slowdown.NewSlowListener(hl, dd.UpKbps*1024, dd.DownKbps*1024) 377 url := formatURL(tlsEnabled, address, hl.Addr().(*net.TCPAddr).Port) 378 logger.Say("Listening on %s (%s)", url, hl.Addr().String()) 379 server := &http.Server{Addr: hl.Addr().String(), Handler: mux} 380 callback(url) 381 382 if dd.HasLivereload() { 383 c := make(chan os.Signal, 1) 384 signal.Notify(c, syscall.SIGHUP) 385 go func() { 386 for { 387 <-c 388 logger.Say("Received signal - reloading") 389 dd.lrserver.Reload([]string{"*"}) 390 } 391 }() 392 } 393 394 err = server.Serve(hl) 395 logger.Shout("Server stopped: %v", err) 396 return nil 397 }