github.com/cortesi/devd@v0.0.0-20200427000907-c1a3bfba27d8/server.go (about)

     1  package devd
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"html/template"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"os/signal"
    11  	"regexp"
    12  	"strings"
    13  	"syscall"
    14  	"time"
    15  
    16  	"golang.org/x/net/context"
    17  
    18  	rice "github.com/GeertJohan/go.rice"
    19  	"github.com/goji/httpauth"
    20  
    21  	"github.com/cortesi/devd/httpctx"
    22  	"github.com/cortesi/devd/inject"
    23  	"github.com/cortesi/devd/livereload"
    24  	"github.com/cortesi/devd/ricetemp"
    25  	"github.com/cortesi/devd/slowdown"
    26  	"github.com/cortesi/devd/timer"
    27  	"github.com/cortesi/termlog"
    28  )
    29  
    30  const (
    31  	// Version is the current version of devd
    32  	Version  = "0.9"
    33  	portLow  = 8000
    34  	portHigh = 10000
    35  )
    36  
    37  func pickPort(addr string, low int, high int, tls bool) (net.Listener, error) {
    38  	firstTry := 80
    39  	if tls {
    40  		firstTry = 443
    41  	}
    42  	hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, firstTry))
    43  	if err == nil {
    44  		return hl, nil
    45  	}
    46  	for i := low; i < high; i++ {
    47  		hl, err := net.Listen("tcp", fmt.Sprintf("%v:%d", addr, i))
    48  		if err == nil {
    49  			return hl, nil
    50  		}
    51  	}
    52  	return nil, fmt.Errorf("Could not find open port.")
    53  }
    54  
    55  func getTLSConfig(path string) (t *tls.Config, err error) {
    56  	config := &tls.Config{}
    57  	if config.NextProtos == nil {
    58  		config.NextProtos = []string{"http/1.1"}
    59  	}
    60  	config.Certificates = make([]tls.Certificate, 1)
    61  	config.Certificates[0], err = tls.LoadX509KeyPair(path, path)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return config, nil
    66  }
    67  
    68  // This filthy hack works in conjunction with hostPortStrip to restore the
    69  // original request host after mux match.
    70  func revertOriginalHost(r *http.Request) {
    71  	original := r.Header.Get("_devd_original_host")
    72  	if original != "" {
    73  		r.Host = original
    74  		r.Header.Del("_devd_original_host")
    75  	}
    76  }
    77  
    78  // We can remove the mangling once this is fixed:
    79  // 		https://github.com/golang/go/issues/10463
    80  func hostPortStrip(next http.Handler) http.Handler {
    81  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    82  		host, _, err := net.SplitHostPort(r.Host)
    83  		if err == nil {
    84  			original := r.Host
    85  			r.Host = host
    86  			r.Header.Set("_devd_original_host", original)
    87  		}
    88  		next.ServeHTTP(w, r)
    89  	})
    90  }
    91  
    92  func matchStringAny(regexps []*regexp.Regexp, s string) bool {
    93  	for _, r := range regexps {
    94  		if r.MatchString(s) {
    95  			return true
    96  		}
    97  	}
    98  	return false
    99  }
   100  
   101  func formatURL(tls bool, httpIP string, port int) string {
   102  	proto := "http"
   103  	if tls {
   104  		proto = "https"
   105  	}
   106  	host := httpIP
   107  	if httpIP == "0.0.0.0" || httpIP == "127.0.0.1" {
   108  		host = "devd.io"
   109  	}
   110  	if port == 443 && tls {
   111  		return fmt.Sprintf("https://%s", host)
   112  	}
   113  	if port == 80 && !tls {
   114  		return fmt.Sprintf("http://%s", host)
   115  	}
   116  	return fmt.Sprintf("%s://%s:%d", proto, host, port)
   117  }
   118  
   119  // Credentials is a simple username/password pair
   120  type Credentials struct {
   121  	username string
   122  	password string
   123  }
   124  
   125  // CredentialsFromSpec creates a set of credentials from a spec
   126  func CredentialsFromSpec(spec string) (*Credentials, error) {
   127  	parts := strings.SplitN(spec, ":", 2)
   128  	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
   129  		return nil, fmt.Errorf("Invalid credential spec: %s", spec)
   130  	}
   131  	return &Credentials{parts[0], parts[1]}, nil
   132  }
   133  
   134  // Devd represents the devd server options
   135  type Devd struct {
   136  	Routes RouteCollection
   137  
   138  	// Shaping
   139  	Latency       int
   140  	DownKbps      uint
   141  	UpKbps        uint
   142  	ServingScheme string
   143  
   144  	// Add headers
   145  	AddHeaders *http.Header
   146  
   147  	// Livereload and watch static routes
   148  	LivereloadRoutes bool
   149  	// Livereload, but don't watch static routes
   150  	Livereload bool
   151  	WatchPaths []string
   152  	Excludes   []string
   153  
   154  	// Add Access-Control-Allow-Origin header
   155  	Cors bool
   156  
   157  	// Logging
   158  	IgnoreLogs []*regexp.Regexp
   159  
   160  	// Password protection
   161  	Credentials *Credentials
   162  
   163  	lrserver *livereload.Server
   164  }
   165  
   166  // WrapHandler wraps an httpctx.Handler in the paraphernalia needed by devd for
   167  // logging, latency, and so forth.
   168  func (dd *Devd) WrapHandler(log termlog.TermLog, next httpctx.Handler) http.Handler {
   169  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   170  		r.URL.Scheme = dd.ServingScheme
   171  		revertOriginalHost(r)
   172  		timr := timer.Timer{}
   173  		sublog := log.Group()
   174  		defer func() {
   175  			timing := termlog.DefaultPalette.Timestamp.SprintFunc()("timing: ")
   176  			sublog.SayAs("timer", timing+timr.String())
   177  			sublog.Done()
   178  		}()
   179  		if matchStringAny(dd.IgnoreLogs, fmt.Sprintf("%s%s", r.URL.Host, r.RequestURI)) {
   180  			sublog.Quiet()
   181  		}
   182  		timr.RequestHeaders()
   183  		time.Sleep(time.Millisecond * time.Duration(dd.Latency))
   184  
   185  		dpath := r.RequestURI
   186  		if !strings.HasPrefix(dpath, "/") {
   187  			dpath = "/" + dpath
   188  		}
   189  		sublog.Say("%s %s", r.Method, dpath)
   190  		LogHeader(sublog, r.Header)
   191  		ctx := timr.NewContext(context.Background())
   192  		ctx = termlog.NewContext(ctx, sublog)
   193  		if dd.AddHeaders != nil {
   194  			for h, vals := range *dd.AddHeaders {
   195  				for _, v := range vals {
   196  					w.Header().Set(h, v)
   197  				}
   198  			}
   199  		}
   200  		if dd.Cors {
   201  			origin := r.Header.Get("Origin")
   202  			if origin == "" {
   203  				origin = "*"
   204  			}
   205  			w.Header().Set("Access-Control-Allow-Origin", origin)
   206  			requestHeaders := r.Header.Get("Access-Control-Request-Headers")
   207  			if requestHeaders != "" {
   208  				w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
   209  			}
   210  			requestMethod := r.Header.Get("Access-Control-Request-Method")
   211  			if requestMethod != "" {
   212  				w.Header().Set("Access-Control-Allow-Methods", requestMethod)
   213  			}
   214  		}
   215  		flusher, _ := w.(http.Flusher)
   216  		next.ServeHTTPContext(
   217  			ctx,
   218  			&ResponseLogWriter{Log: sublog, Resp: w, Flusher: flusher, Timer: &timr},
   219  			r,
   220  		)
   221  	})
   222  	return h
   223  }
   224  
   225  // HasLivereload tells us if livereload is enabled
   226  func (dd *Devd) HasLivereload() bool {
   227  	if dd.Livereload || dd.LivereloadRoutes || len(dd.WatchPaths) > 0 {
   228  		return true
   229  	}
   230  	return false
   231  }
   232  
   233  // AddRoutes adds route specifications to the server
   234  func (dd *Devd) AddRoutes(specs []string, notfound []string) error {
   235  	dd.Routes = make(RouteCollection)
   236  	for _, s := range specs {
   237  		err := dd.Routes.Add(s, notfound)
   238  		if err != nil {
   239  			return fmt.Errorf("Invalid route specification: %s", err)
   240  		}
   241  	}
   242  	return nil
   243  }
   244  
   245  // AddIgnores adds log ignore patterns to the server
   246  func (dd *Devd) AddIgnores(specs []string) error {
   247  	dd.IgnoreLogs = make([]*regexp.Regexp, 0, 0)
   248  	for _, expr := range specs {
   249  		v, err := regexp.Compile(expr)
   250  		if err != nil {
   251  			return fmt.Errorf("%s", err)
   252  		}
   253  		dd.IgnoreLogs = append(dd.IgnoreLogs, v)
   254  	}
   255  	return nil
   256  }
   257  
   258  // HandleNotFound handles pages not found. In particular, this handler is used
   259  // when we have no matching route for a request. This also means it's not
   260  // useful to inject the livereload paraphernalia here.
   261  func HandleNotFound(templates *template.Template) httpctx.Handler {
   262  	return httpctx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
   263  		w.WriteHeader(http.StatusNotFound)
   264  		err := templates.Lookup("404.html").Execute(w, nil)
   265  		if err != nil {
   266  			logger := termlog.FromContext(ctx)
   267  			logger.Shout("Could not execute template: %s", err)
   268  		}
   269  	})
   270  }
   271  
   272  // Router constructs the main Devd router that serves all requests
   273  func (dd *Devd) Router(logger termlog.TermLog, templates *template.Template) (http.Handler, error) {
   274  	mux := http.NewServeMux()
   275  	hasGlobal := false
   276  
   277  	ci := inject.CopyInject{}
   278  	if dd.HasLivereload() {
   279  		ci = livereload.Injector
   280  	}
   281  
   282  	for match, route := range dd.Routes {
   283  		if match == "/" {
   284  			hasGlobal = true
   285  		}
   286  		handler := dd.WrapHandler(
   287  			logger,
   288  			route.Endpoint.Handler(route.Path, templates, ci),
   289  		)
   290  		mux.Handle(match, handler)
   291  	}
   292  	if dd.HasLivereload() {
   293  		lr := livereload.NewServer("livereload", logger)
   294  		mux.Handle(livereload.EndpointPath, lr)
   295  		mux.Handle(livereload.ScriptPath, http.HandlerFunc(lr.ServeScript))
   296  		seen := make(map[string]bool)
   297  		for _, route := range dd.Routes {
   298  			if _, ok := seen[route.Host]; route.Host != "" && ok == false {
   299  				mux.Handle(route.Host+livereload.EndpointPath, lr)
   300  				mux.Handle(
   301  					route.Host+livereload.ScriptPath,
   302  					http.HandlerFunc(lr.ServeScript),
   303  				)
   304  				seen[route.Host] = true
   305  			}
   306  		}
   307  		if dd.LivereloadRoutes {
   308  			err := WatchRoutes(dd.Routes, lr, dd.Excludes, logger)
   309  			if err != nil {
   310  				return nil, fmt.Errorf("Could not watch routes for livereload: %s", err)
   311  			}
   312  		}
   313  		if len(dd.WatchPaths) > 0 {
   314  			err := WatchPaths(dd.WatchPaths, dd.Excludes, lr, logger)
   315  			if err != nil {
   316  				return nil, fmt.Errorf("Could not watch path for livereload: %s", err)
   317  			}
   318  		}
   319  		dd.lrserver = lr
   320  	}
   321  	if !hasGlobal {
   322  		mux.Handle(
   323  			"/",
   324  			dd.WrapHandler(logger, HandleNotFound(templates)),
   325  		)
   326  	}
   327  	var h = http.Handler(mux)
   328  	if dd.Credentials != nil {
   329  		h = httpauth.SimpleBasicAuth(
   330  			dd.Credentials.username, dd.Credentials.password,
   331  		)(h)
   332  	}
   333  	return hostPortStrip(h), nil
   334  }
   335  
   336  // Serve starts the devd server. The callback is called with the serving URL
   337  // just before service starts.
   338  func (dd *Devd) Serve(address string, port int, certFile string, logger termlog.TermLog, callback func(string)) error {
   339  	templates, err := ricetemp.MakeTemplates(rice.MustFindBox("templates"))
   340  	if err != nil {
   341  		return fmt.Errorf("Error loading templates: %s", err)
   342  	}
   343  	mux, err := dd.Router(logger, templates)
   344  	if err != nil {
   345  		return err
   346  	}
   347  	var tlsConfig *tls.Config
   348  	var tlsEnabled bool
   349  	if certFile != "" {
   350  		tlsConfig, err = getTLSConfig(certFile)
   351  		if err != nil {
   352  			return fmt.Errorf("Could not load certs: %s", err)
   353  		}
   354  		tlsEnabled = true
   355  	}
   356  
   357  	var hl net.Listener
   358  	if port > 0 {
   359  		hl, err = net.Listen("tcp", fmt.Sprintf("%v:%d", address, port))
   360  	} else {
   361  		hl, err = pickPort(address, portLow, portHigh, tlsEnabled)
   362  	}
   363  	if err != nil {
   364  		return err
   365  	}
   366  
   367  	if tlsConfig != nil {
   368  		hl = tls.NewListener(hl, tlsConfig)
   369  	}
   370  
   371  	hl = slowdown.NewSlowListener(hl, dd.UpKbps*1024, dd.DownKbps*1024)
   372  	url := formatURL(tlsEnabled, address, hl.Addr().(*net.TCPAddr).Port)
   373  	logger.Say("Listening on %s (%s)", url, hl.Addr().String())
   374  	server := &http.Server{Addr: hl.Addr().String(), Handler: mux}
   375  	callback(url)
   376  
   377  	if dd.HasLivereload() {
   378  		c := make(chan os.Signal, 1)
   379  		signal.Notify(c, syscall.SIGHUP)
   380  		go func() {
   381  			for {
   382  				<-c
   383  				logger.Say("Received signal - reloading")
   384  				dd.lrserver.Reload([]string{"*"})
   385  			}
   386  		}()
   387  	}
   388  
   389  	err = server.Serve(hl)
   390  	logger.Shout("Server stopped: %v", err)
   391  	return nil
   392  }