github.com/llimllib/devd@v0.0.0-20230426145215-4d29fc25f909/server.go (about)

     1  package devd
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"html/template"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"os/signal"
    11  	"regexp"
    12  	"strings"
    13  	"syscall"
    14  	"time"
    15  
    16  	"golang.org/x/net/context"
    17  
    18  	rice "github.com/GeertJohan/go.rice"
    19  	"github.com/goji/httpauth"
    20  
    21  	"github.com/cortesi/termlog"
    22  	"github.com/llimllib/devd/httpctx"
    23  	"github.com/llimllib/devd/inject"
    24  	"github.com/llimllib/devd/livereload"
    25  	"github.com/llimllib/devd/ricetemp"
    26  	"github.com/llimllib/devd/slowdown"
    27  	"github.com/llimllib/devd/timer"
    28  )
    29  
    30  const (
    31  	// Version is the current version of devd
    32  	Version  = "0.9"
    33  	portLow  = 8000
    34  	portHigh = 10000
    35  )
    36  
    37  func pickPort(addr string, low int, high int, tls bool) (net.Listener, error) {
    38  	firstTry := 80
    39  	if tls {
    40  		firstTry = 443
    41  	}
    42  	hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, firstTry))
    43  	if err == nil {
    44  		return hl, nil
    45  	}
    46  	for i := low; i < high; i++ {
    47  		hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, i))
    48  		if err == nil {
    49  			return hl, nil
    50  		}
    51  	}
    52  	return nil, fmt.Errorf("could not find open port")
    53  }
    54  
    55  func getTLSConfig(path string) (t *tls.Config, err error) {
    56  	config := &tls.Config{}
    57  	if config.NextProtos == nil {
    58  		config.NextProtos = []string{"http/1.1"}
    59  	}
    60  	config.Certificates = make([]tls.Certificate, 1)
    61  	config.Certificates[0], err = tls.LoadX509KeyPair(path, path)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return config, nil
    66  }
    67  
    68  // This filthy hack works in conjunction with hostPortStrip to restore the
    69  // original request host after mux match.
    70  func revertOriginalHost(r *http.Request) {
    71  	original := r.Header.Get("_devd_original_host")
    72  	if original != "" {
    73  		r.Host = original
    74  		r.Header.Del("_devd_original_host")
    75  	}
    76  }
    77  
    78  // We can remove the mangling once this is fixed:
    79  // 		https://github.com/golang/go/issues/10463
    80  func hostPortStrip(next http.Handler) http.Handler {
    81  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    82  		host, _, err := net.SplitHostPort(r.Host)
    83  		if err == nil {
    84  			original := r.Host
    85  			r.Host = host
    86  			r.Header.Set("_devd_original_host", original)
    87  		}
    88  		next.ServeHTTP(w, r)
    89  	})
    90  }
    91  
    92  func matchStringAny(regexps []*regexp.Regexp, s string) bool {
    93  	for _, r := range regexps {
    94  		if r.MatchString(s) {
    95  			return true
    96  		}
    97  	}
    98  	return false
    99  }
   100  
   101  func formatURL(tls bool, httpIP string, port int) string {
   102  	proto := "http"
   103  	if tls {
   104  		proto = "https"
   105  	}
   106  	host := httpIP
   107  	if httpIP == "0.0.0.0" || httpIP == "127.0.0.1" {
   108  		host = "devd.io"
   109  	}
   110  	if port == 443 && tls {
   111  		return fmt.Sprintf("https://%s", host)
   112  	}
   113  	if port == 80 && !tls {
   114  		return fmt.Sprintf("http://%s", host)
   115  	}
   116  	return fmt.Sprintf("%s://%s:%d", proto, host, port)
   117  }
   118  
   119  // Credentials is a simple username/password pair
   120  type Credentials struct {
   121  	username string
   122  	password string
   123  }
   124  
   125  // CredentialsFromSpec creates a set of credentials from a spec
   126  func CredentialsFromSpec(spec string) (*Credentials, error) {
   127  	parts := strings.SplitN(spec, ":", 2)
   128  	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
   129  		return nil, fmt.Errorf("invalid credential spec: %s", spec)
   130  	}
   131  	return &Credentials{parts[0], parts[1]}, nil
   132  }
   133  
// Devd represents the devd server options. The exported fields are
// configuration read by WrapHandler, Router and Serve; lrserver is internal
// state filled in by Router.
type Devd struct {
	// Routes holds the route table. AddRoutes replaces it, and Router
	// registers one handler per entry.
	Routes RouteCollection

	// Shaping

	// Latency is the delay, in milliseconds, that WrapHandler sleeps before
	// handling each request.
	Latency       int
	// DownKbps and UpKbps are multiplied by 1024 and passed to
	// slowdown.NewSlowListener by Serve.
	// NOTE(review): presumably 0 means "unthrottled" - confirm in slowdown.
	DownKbps      uint
	UpKbps        uint
	// ServingScheme is copied into r.URL.Scheme for every wrapped request.
	ServingScheme string

	// AddHeaders, when non-nil, lists headers that WrapHandler writes onto
	// every response before the wrapped handler runs.
	AddHeaders *http.Header

	// Livereload and watch static routes
	LivereloadRoutes bool
	// Livereload, but don't watch static routes
	Livereload bool
	// WatchPaths are extra paths handed to WatchPaths by Router; a non-empty
	// list also enables livereload (see HasLivereload).
	WatchPaths []string
	// Excludes is passed to both WatchRoutes and WatchPaths.
	Excludes   []string

	// Cors makes WrapHandler add Access-Control-Allow-Origin and the related
	// cross-origin response headers.
	Cors bool

	// IgnoreLogs lists patterns matched against the request host followed by
	// the request URI; the log group of a matching request is quieted.
	IgnoreLogs []*regexp.Regexp

	// Credentials, when non-nil, makes Router wrap the mux in HTTP basic
	// authentication using this username/password pair.
	Credentials *Credentials

	// lrserver is the livereload server created by Router when livereload is
	// enabled; Serve uses it to trigger a reload on SIGHUP.
	lrserver *livereload.Server
}
   165  
   166  // WrapHandler wraps an httpctx.Handler in the paraphernalia needed by devd for
   167  // logging, latency, and so forth.
   168  func (dd *Devd) WrapHandler(log termlog.TermLog, next httpctx.Handler) http.Handler {
   169  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   170  		r.URL.Scheme = dd.ServingScheme
   171  		revertOriginalHost(r)
   172  		timr := timer.Timer{}
   173  		sublog := log.Group()
   174  		defer func() {
   175  			timing := termlog.DefaultPalette.Timestamp.SprintFunc()("timing: ")
   176  			sublog.SayAs("timer", timing+timr.String())
   177  			sublog.Done()
   178  		}()
   179  		if matchStringAny(dd.IgnoreLogs, fmt.Sprintf("%s%s", r.URL.Host, r.RequestURI)) {
   180  			sublog.Quiet()
   181  		}
   182  		timr.RequestHeaders()
   183  		time.Sleep(time.Millisecond * time.Duration(dd.Latency))
   184  
   185  		dpath := r.RequestURI
   186  		if !strings.HasPrefix(dpath, "/") {
   187  			dpath = "/" + dpath
   188  		}
   189  		sublog.Say("%s %s", r.Method, dpath)
   190  		LogHeader(sublog, r.Header)
   191  		ctx := timr.NewContext(context.Background())
   192  		ctx = termlog.NewContext(ctx, sublog)
   193  		if dd.AddHeaders != nil {
   194  			for h, vals := range *dd.AddHeaders {
   195  				for _, v := range vals {
   196  					w.Header().Set(h, v)
   197  				}
   198  			}
   199  		}
   200  		if dd.Cors {
   201  			origin := r.Header.Get("Origin")
   202  			if origin == "" {
   203  				origin = "*"
   204  			}
   205  			w.Header().Set("Access-Control-Allow-Origin", origin)
   206  			requestHeaders := r.Header.Get("Access-Control-Request-Headers")
   207  			if requestHeaders != "" {
   208  				w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
   209  			}
   210  			requestMethod := r.Header.Get("Access-Control-Request-Method")
   211  			if requestMethod != "" {
   212  				w.Header().Set("Access-Control-Allow-Methods", requestMethod)
   213  			}
   214  
   215  			// required for SharedArrayBuffer usage
   216  			// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer
   217  			w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
   218  			w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
   219  		}
   220  		flusher, _ := w.(http.Flusher)
   221  		next.ServeHTTPContext(
   222  			ctx,
   223  			&ResponseLogWriter{Log: sublog, Resp: w, Flusher: flusher, Timer: &timr},
   224  			r,
   225  		)
   226  	})
   227  	return h
   228  }
   229  
   230  // HasLivereload tells us if livereload is enabled
   231  func (dd *Devd) HasLivereload() bool {
   232  	if dd.Livereload || dd.LivereloadRoutes || len(dd.WatchPaths) > 0 {
   233  		return true
   234  	}
   235  	return false
   236  }
   237  
   238  // AddRoutes adds route specifications to the server
   239  func (dd *Devd) AddRoutes(specs []string, notfound []string) error {
   240  	dd.Routes = make(RouteCollection)
   241  	for _, s := range specs {
   242  		err := dd.Routes.Add(s, notfound)
   243  		if err != nil {
   244  			return fmt.Errorf("invalid route specification: %s", err)
   245  		}
   246  	}
   247  	return nil
   248  }
   249  
   250  // AddIgnores adds log ignore patterns to the server
   251  func (dd *Devd) AddIgnores(specs []string) error {
   252  	dd.IgnoreLogs = make([]*regexp.Regexp, 0)
   253  	for _, expr := range specs {
   254  		v, err := regexp.Compile(expr)
   255  		if err != nil {
   256  			return fmt.Errorf("%s", err)
   257  		}
   258  		dd.IgnoreLogs = append(dd.IgnoreLogs, v)
   259  	}
   260  	return nil
   261  }
   262  
   263  // HandleNotFound handles pages not found. In particular, this handler is used
   264  // when we have no matching route for a request. This also means it's not
   265  // useful to inject the livereload paraphernalia here.
   266  func HandleNotFound(templates *template.Template) httpctx.Handler {
   267  	return httpctx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, _ *http.Request) {
   268  		w.WriteHeader(http.StatusNotFound)
   269  		err := templates.Lookup("404.html").Execute(w, nil)
   270  		if err != nil {
   271  			logger := termlog.FromContext(ctx)
   272  			logger.Shout("Could not execute template: %s", err)
   273  		}
   274  	})
   275  }
   276  
   277  // Router constructs the main Devd router that serves all requests
   278  func (dd *Devd) Router(logger termlog.TermLog, templates *template.Template) (http.Handler, error) {
   279  	mux := http.NewServeMux()
   280  	hasGlobal := false
   281  
   282  	ci := inject.CopyInject{}
   283  	if dd.HasLivereload() {
   284  		ci = livereload.Injector
   285  	}
   286  
   287  	for match, route := range dd.Routes {
   288  		if match == "/" {
   289  			hasGlobal = true
   290  		}
   291  		handler := dd.WrapHandler(
   292  			logger,
   293  			route.Endpoint.Handler(route.Path, templates, ci),
   294  		)
   295  		mux.Handle(match, handler)
   296  	}
   297  	if dd.HasLivereload() {
   298  		lr := livereload.NewServer("livereload", logger)
   299  		mux.Handle(livereload.EndpointPath, lr)
   300  		mux.Handle(livereload.ScriptPath, http.HandlerFunc(lr.ServeScript))
   301  		seen := make(map[string]bool)
   302  		for _, route := range dd.Routes {
   303  			if _, ok := seen[route.Host]; route.Host != "" && !ok {
   304  				mux.Handle(route.Host+livereload.EndpointPath, lr)
   305  				mux.Handle(
   306  					route.Host+livereload.ScriptPath,
   307  					http.HandlerFunc(lr.ServeScript),
   308  				)
   309  				seen[route.Host] = true
   310  			}
   311  		}
   312  		if dd.LivereloadRoutes {
   313  			err := WatchRoutes(dd.Routes, lr, dd.Excludes, logger)
   314  			if err != nil {
   315  				return nil, fmt.Errorf("could not watch routes for livereload: %s", err)
   316  			}
   317  		}
   318  		if len(dd.WatchPaths) > 0 {
   319  			err := WatchPaths(dd.WatchPaths, dd.Excludes, lr, logger)
   320  			if err != nil {
   321  				return nil, fmt.Errorf("could not watch path for livereload: %s", err)
   322  			}
   323  		}
   324  		dd.lrserver = lr
   325  	}
   326  	if !hasGlobal {
   327  		mux.Handle(
   328  			"/",
   329  			dd.WrapHandler(logger, HandleNotFound(templates)),
   330  		)
   331  	}
   332  	h := http.Handler(mux)
   333  	if dd.Credentials != nil {
   334  		h = httpauth.SimpleBasicAuth(
   335  			dd.Credentials.username, dd.Credentials.password,
   336  		)(h)
   337  	}
   338  	return hostPortStrip(h), nil
   339  }
   340  
   341  // Serve starts the devd server. The callback is called with the serving URL
   342  // just before service starts.
   343  func (dd *Devd) Serve(address string, port int, certFile string, logger termlog.TermLog, callback func(string)) error {
   344  	templates, err := ricetemp.MakeTemplates(rice.MustFindBox("templates"))
   345  	if err != nil {
   346  		return fmt.Errorf("error loading templates: %s", err)
   347  	}
   348  	mux, err := dd.Router(logger, templates)
   349  	if err != nil {
   350  		return err
   351  	}
   352  	var tlsConfig *tls.Config
   353  	var tlsEnabled bool
   354  	if certFile != "" {
   355  		tlsConfig, err = getTLSConfig(certFile)
   356  		if err != nil {
   357  			return fmt.Errorf("could not load certs: %s", err)
   358  		}
   359  		tlsEnabled = true
   360  	}
   361  
   362  	var hl net.Listener
   363  	if port > 0 {
   364  		hl, err = net.Listen("tcp", fmt.Sprintf("%v:%d", address, port))
   365  	} else {
   366  		hl, err = pickPort(address, portLow, portHigh, tlsEnabled)
   367  	}
   368  	if err != nil {
   369  		return err
   370  	}
   371  
   372  	if tlsConfig != nil {
   373  		hl = tls.NewListener(hl, tlsConfig)
   374  	}
   375  
   376  	hl = slowdown.NewSlowListener(hl, dd.UpKbps*1024, dd.DownKbps*1024)
   377  	url := formatURL(tlsEnabled, address, hl.Addr().(*net.TCPAddr).Port)
   378  	logger.Say("Listening on %s (%s)", url, hl.Addr().String())
   379  	server := &http.Server{Addr: hl.Addr().String(), Handler: mux}
   380  	callback(url)
   381  
   382  	if dd.HasLivereload() {
   383  		c := make(chan os.Signal, 1)
   384  		signal.Notify(c, syscall.SIGHUP)
   385  		go func() {
   386  			for {
   387  				<-c
   388  				logger.Say("Received signal - reloading")
   389  				dd.lrserver.Reload([]string{"*"})
   390  			}
   391  		}()
   392  	}
   393  
   394  	err = server.Serve(hl)
   395  	logger.Shout("Server stopped: %v", err)
   396  	return nil
   397  }