github.com/blend/go-sdk@v1.20220411.3/web/app.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package web
     9  
    10  import (
    11  	"context"
    12  	"crypto/tls"
    13  	"net"
    14  	"net/http"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/blend/go-sdk/async"
    19  	"github.com/blend/go-sdk/ex"
    20  	"github.com/blend/go-sdk/logger"
    21  	"github.com/blend/go-sdk/proxyprotocol"
    22  	"github.com/blend/go-sdk/webutil"
    23  )
    24  
    25  // MustNew creates a new app and panics if there is an error.
    26  func MustNew(options ...Option) *App {
    27  	app, err := New(options...)
    28  	if err != nil {
    29  		panic(err)
    30  	}
    31  	return app
    32  }
    33  
    34  // New returns a new web app.
    35  func New(options ...Option) (*App, error) {
    36  	views, err := NewViewCache()
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	auth, err := NewAuthManager()
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	a := App{
    45  		RouteTree:       new(RouteTree),
    46  		Auth:            auth,
    47  		BaseContext:     func(_ net.Listener) context.Context { return context.Background() },
    48  		BaseHeaders:     BaseHeaders(),
    49  		BaseState:       new(SyncState),
    50  		DefaultProvider: views,
    51  		Latch:           async.NewLatch(),
    52  		Server:          new(http.Server),
    53  		Statics:         map[string]*StaticFileServer{},
    54  		Views:           views,
    55  	}
    56  
    57  	for _, option := range options {
    58  		if err = option(&a); err != nil {
    59  			return nil, err
    60  		}
    61  	}
    62  	return &a, nil
    63  }
    64  
// App is the server for the app.
//
// It bundles the route tree, the lifecycle latch consulted by Start and Stop,
// the underlying http.Server and listener, and the shared collaborators (auth,
// views, logger, tracer, base state/headers/middleware) handed to every
// request's Ctx. New populates the defaults for these fields.
type App struct {
	// Latch tracks the start/stop lifecycle. Start checks CanStart and Stop
	// checks CanStop before doing any work.
	*async.Latch
	// RouteTree holds the registered routes; ServeHTTP dispatches into it.
	*RouteTree

	// Config supplies the bind address, timeouts, keep-alive settings, proxy
	// protocol flag, shutdown grace period and panic-recovery switch.
	Config Config

	// Auth is passed to each request Ctx.
	Auth        AuthManager
	// BaseContext derives the base context from the listener. It is used by
	// Background and installed on the http.Server by Start.
	BaseContext func(net.Listener) context.Context

	// BaseHeaders are set on every response rendered by RenderAction and
	// RenderActionBare.
	BaseHeaders    http.Header
	// BaseMiddleware is appended after route-specific middleware at route
	// registration time (Method/MethodBare). Changing it later does not affect
	// routes that are already registered.
	BaseMiddleware []Middleware
	// BaseState is copied into each request Ctx.
	BaseState      State

	// Log is optional; logging is skipped when it is not set.
	Log    logger.Log
	// Tracer is optional; when non-nil, RenderAction starts a trace per request.
	Tracer Tracer

	// TLSConfig is applied to Server by Start. When the server ends up with a
	// TLS config, the listener Start creates is wrapped with TLS.
	TLSConfig *tls.Config
	// Server is the underlying http.Server, configured by Start.
	Server    *http.Server
	// Listener, if nil, is created by Start as a TCP listener. Otherwise it is
	// used as-is, with no keep-alive, proxy-protocol or TLS wrapping.
	Listener  net.Listener

	// Statics maps normalized mount routes (ending in "/*filepath") to their
	// static file servers.
	Statics map[string]*StaticFileServer

	// DefaultProvider is passed to each request Ctx; New sets it to the view cache.
	DefaultProvider ResultProvider
	// Views is initialized by StartupTasks and passed to each request Ctx.
	Views           *ViewCache

	// PanicAction, if set, produces the result rendered after a recovered
	// panic. Otherwise a plain 500 response is written.
	PanicAction PanicAction
}
    93  
    94  // Background returns a base context.
    95  func (a *App) Background() context.Context {
    96  	if a.BaseContext != nil {
    97  		return a.BaseContext(a.Listener)
    98  	}
    99  	return context.Background()
   100  }
   101  
   102  // --------------------------------------------------------------------------------
   103  // Lifecycle
   104  // --------------------------------------------------------------------------------
   105  
   106  // Start starts the server and binds to the given address.
   107  func (a *App) Start() (err error) {
   108  	if !a.Latch.CanStart() {
   109  		return ex.New(async.ErrCannotStart)
   110  	}
   111  	for _, opt := range a.httpServerOptions() {
   112  		if err = opt(a.Server); err != nil {
   113  			return err
   114  		}
   115  	}
   116  
   117  	err = a.StartupTasks()
   118  	if err != nil {
   119  		return
   120  	}
   121  
   122  	var shutdownErr error
   123  	if a.Listener == nil {
   124  		serverProtocol := "http"
   125  		if a.Server.TLSConfig != nil {
   126  			serverProtocol = "https (tls)"
   127  		}
   128  		if a.Server.Addr == "" {
   129  			a.Server.Addr = a.Config.BindAddrOrDefault()
   130  		}
   131  
   132  		var rawListener net.Listener
   133  		rawListener, err = net.Listen("tcp", a.Server.Addr)
   134  		if err != nil {
   135  			err = ex.New(err)
   136  			return
   137  		}
   138  		typedListener, ok := rawListener.(*net.TCPListener)
   139  		if !ok {
   140  			err = ex.New("listener returned was not a net.TCPListener")
   141  			return
   142  		}
   143  		a.Listener = webutil.TCPKeepAliveListener{
   144  			TCPListener:     typedListener,
   145  			KeepAlive:       a.Config.KeepAliveOrDefault(),
   146  			KeepAlivePeriod: a.Config.KeepAlivePeriodOrDefault(),
   147  		}
   148  
   149  		if a.Config.UseProxyProtocol {
   150  			logger.MaybeInfofContext(a.Background(), a.Log, "%s using proxy protocol", serverProtocol)
   151  			a.Listener = &proxyprotocol.Listener{Listener: a.Listener}
   152  		}
   153  
   154  		if a.Server.TLSConfig != nil {
   155  			logger.MaybeInfofContext(a.Background(), a.Log, "%s using tls", serverProtocol)
   156  			a.Listener = tls.NewListener(a.Listener, a.Server.TLSConfig)
   157  		}
   158  		logger.MaybeInfofContext(a.Background(), a.Log, "%s server started, listening on %s", serverProtocol, a.Server.Addr)
   159  	} else {
   160  		logger.MaybeInfofContext(a.Background(), a.Log, "http server started, using custom listener")
   161  	}
   162  
   163  	a.Started()
   164  	shutdownErr = a.Server.Serve(a.Listener)
   165  	if shutdownErr != nil && shutdownErr != http.ErrServerClosed {
   166  		err = ex.New(shutdownErr)
   167  	}
   168  	logger.MaybeInfofContext(a.Background(), a.Log, "server stopped serving")
   169  	a.Stopped()
   170  	return
   171  }
   172  
   173  // Stop stops the server.
   174  func (a *App) Stop() error {
   175  	if !a.CanStop() {
   176  		return ex.New(async.ErrCannotStop)
   177  	}
   178  	a.Stopping()
   179  
   180  	ctx := a.Background()
   181  	var cancel context.CancelFunc
   182  	if gracePeriod := a.Config.ShutdownGracePeriodOrDefault(); gracePeriod > 0 {
   183  		logger.MaybeInfofContext(ctx, a.Log, "server shutdown grace period: %v", gracePeriod)
   184  		ctx, cancel = context.WithTimeout(ctx, gracePeriod)
   185  		defer cancel()
   186  	}
   187  	logger.MaybeInfofContext(ctx, a.Log, "server keep alives disabled")
   188  	a.Server.SetKeepAlivesEnabled(false)
   189  	logger.MaybeInfofContext(ctx, a.Log, "server shutting down")
   190  	if err := a.Server.Shutdown(ctx); err != nil {
   191  		if err == context.DeadlineExceeded {
   192  			logger.MaybeWarningfContext(ctx, a.Log, "server shutdown grace period exceeded, connections forcibly closed")
   193  		} else {
   194  			return ex.New(err)
   195  		}
   196  	}
   197  	logger.MaybeInfofContext(a.Background(), a.Log, "server shutdown complete")
   198  	return nil
   199  }
   200  
   201  // --------------------------------------------------------------------------------
   202  // Register Controllers
   203  // --------------------------------------------------------------------------------
   204  
   205  // Register registers controllers with the app's router.
   206  func (a *App) Register(controllers ...Controller) {
   207  	for _, c := range controllers {
   208  		c.Register(a)
   209  	}
   210  }
   211  
   212  // --------------------------------------------------------------------------------
   213  // Static Result Methods
   214  // --------------------------------------------------------------------------------
   215  
   216  // ServeStatic serves files from the given file system root(s)..
   217  // If the path does not end with "/*filepath" that suffix will be added for you internally.
   218  // For example if root is "/etc" and *filepath is "passwd", the local file
   219  // "/etc/passwd" would be served.
   220  func (a *App) ServeStatic(route string, searchPaths []string, middleware ...Middleware) {
   221  	var searchPathFS []http.FileSystem
   222  	for _, searchPath := range searchPaths {
   223  		searchPathFS = append(searchPathFS, http.Dir(searchPath))
   224  	}
   225  	sfs := NewStaticFileServer(
   226  		OptStaticFileServerSearchPaths(searchPathFS...),
   227  		OptStaticFileServerCacheDisabled(true),
   228  	)
   229  	mountedRoute := a.formatStaticMountRoute(route)
   230  	a.Statics[mountedRoute] = sfs
   231  	a.Method(webutil.MethodGet, mountedRoute, sfs.Action, middleware...)
   232  }
   233  
   234  // ServeStaticCached serves files from the given file system root(s).
   235  // If the path does not end with "/*filepath" that suffix will be added for you internally.
   236  func (a *App) ServeStaticCached(route string, searchPaths []string, middleware ...Middleware) {
   237  	var searchPathFS []http.FileSystem
   238  	for _, searchPath := range searchPaths {
   239  		searchPathFS = append(searchPathFS, http.Dir(searchPath))
   240  	}
   241  	sfs := NewStaticFileServer(
   242  		OptStaticFileServerSearchPaths(searchPathFS...),
   243  	)
   244  	mountedRoute := a.formatStaticMountRoute(route)
   245  	a.Statics[mountedRoute] = sfs
   246  	a.Method(webutil.MethodGet, mountedRoute, sfs.Action, middleware...)
   247  }
   248  
   249  // SetStaticRewriteRule adds a rewrite rule for a specific statically served path.
   250  // It mutates the path for the incoming static file request to the fileserver according to the action.
   251  func (a *App) SetStaticRewriteRule(route, match string, action RewriteAction) error {
   252  	mountedRoute := a.formatStaticMountRoute(route)
   253  	if static, hasRoute := a.Statics[mountedRoute]; hasRoute {
   254  		return static.AddRewriteRule(match, action)
   255  	}
   256  	return ex.New("no static fileserver mounted at route", ex.OptMessagef("route: %s", route))
   257  }
   258  
   259  // SetStaticHeader adds a header for the given static path.
   260  // These headers are automatically added to any result that the static path fileserver sends.
   261  func (a *App) SetStaticHeader(route, key, value string) error {
   262  	mountedRoute := a.formatStaticMountRoute(route)
   263  	if static, hasRoute := a.Statics[mountedRoute]; hasRoute {
   264  		static.AddHeader(key, value)
   265  		return nil
   266  	}
   267  	return ex.New("no static fileserver mounted at route", ex.OptMessagef("route: %s", mountedRoute))
   268  }
   269  
   270  // --------------------------------------------------------------------------------
   271  // Route Registration / HTTP Methods
   272  // --------------------------------------------------------------------------------
   273  
   274  // GET registers a GET request route handler with the given middleware.
   275  func (a *App) GET(path string, action Action, middleware ...Middleware) {
   276  	a.Method(http.MethodGet, path, action, middleware...)
   277  }
   278  
   279  // OPTIONS registers a OPTIONS request route handler the given middleware.
   280  func (a *App) OPTIONS(path string, action Action, middleware ...Middleware) {
   281  	a.Method(http.MethodOptions, path, action, middleware...)
   282  }
   283  
   284  // HEAD registers a HEAD request route handler with the given middleware.
   285  func (a *App) HEAD(path string, action Action, middleware ...Middleware) {
   286  	a.Method(http.MethodHead, path, action, middleware...)
   287  }
   288  
   289  // PUT registers a PUT request route handler with the given middleware.
   290  func (a *App) PUT(path string, action Action, middleware ...Middleware) {
   291  	a.Method(http.MethodPut, path, action, middleware...)
   292  }
   293  
   294  // PATCH registers a PATCH request route handler with the given middleware.
   295  func (a *App) PATCH(path string, action Action, middleware ...Middleware) {
   296  	a.Method(http.MethodPatch, path, action, middleware...)
   297  }
   298  
   299  // POST registers a POST request route handler with the given middleware.
   300  func (a *App) POST(path string, action Action, middleware ...Middleware) {
   301  	a.Method(http.MethodPost, path, action, middleware...)
   302  }
   303  
   304  // DELETE registers a DELETE request route handler with the given middleware.
   305  func (a *App) DELETE(path string, action Action, middleware ...Middleware) {
   306  	a.Method(http.MethodDelete, path, action, middleware...)
   307  }
   308  
   309  // Method registers an action for a given method and path with the given middleware.
   310  func (a *App) Method(method string, path string, action Action, middleware ...Middleware) {
   311  	a.RouteTree.Handle(method, path, a.RenderAction(NestMiddleware(action, append(middleware, a.BaseMiddleware...)...)))
   312  }
   313  
   314  // MethodBare registers an action for a given method and path with the given middleware that omits logging and tracing.
   315  func (a *App) MethodBare(method string, path string, action Action, middleware ...Middleware) {
   316  	a.RouteTree.Handle(method, path, a.RenderActionBare(NestMiddleware(action, append(middleware, a.BaseMiddleware...)...)))
   317  }
   318  
   319  // Lookup finds the route data for a given method and path.
   320  func (a *App) Lookup(method, path string) (route *Route, params RouteParameters, skipSlashRedirect bool) {
   321  	if root := a.RouteTree.Routes[method]; root != nil {
   322  		route, params, skipSlashRedirect = root.getValue(path)
   323  		return
   324  	}
   325  	return
   326  }
   327  
   328  // --------------------------------------------------------------------------------
   329  // Request Pipeline
   330  // --------------------------------------------------------------------------------
   331  
   332  // ServeHTTP makes the router implement the http.Handler interface.
   333  func (a *App) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   334  	if !a.Config.DisablePanicRecovery {
   335  		defer a.recover(w, req)
   336  	}
   337  	// load the request start time onto the request.
   338  	req = req.WithContext(WithRequestStarted(req.Context(), time.Now().UTC()))
   339  	a.RouteTree.ServeHTTP(w, req)
   340  }
   341  
   342  // RenderAction is the translation step from Action to Handler.
   343  func (a *App) RenderAction(action Action) Handler {
   344  	return func(w http.ResponseWriter, r *http.Request, route *Route, p RouteParameters) {
   345  		ctx := NewCtx(webutil.NewStatusResponseWriter(w), r, a.ctxOptions(r.Context(), route, p)...)
   346  		defer ctx.Close()
   347  		defer a.logRequest(ctx)
   348  
   349  		var err error
   350  		if a.Tracer != nil {
   351  			tf := ctx.Tracer.Start(ctx)
   352  			defer func() {
   353  				tf.Finish(ctx, err)
   354  			}()
   355  		}
   356  
   357  		for key, value := range a.BaseHeaders {
   358  			ctx.Response.Header()[key] = value
   359  		}
   360  
   361  		if result := action(ctx); result != nil {
   362  			if typed, ok := result.(ResultPreRender); ok {
   363  				if errPreRender := typed.PreRender(ctx); errPreRender != nil {
   364  					a.maybeLogFatal(ctx.Context(), errPreRender, ctx.Request)
   365  					err = ex.New(errPreRender, ex.OptInner(err))
   366  				}
   367  			}
   368  			if errRender := result.Render(ctx); errRender != nil {
   369  				a.maybeLogFatal(ctx.Context(), errRender, ctx.Request)
   370  				err = ex.New(errRender, ex.OptInner(err))
   371  			}
   372  			if typed, ok := result.(ResultPostRender); ok {
   373  				if errPostRender := typed.PostRender(ctx); errPostRender != nil {
   374  					a.maybeLogFatal(ctx.Context(), errPostRender, ctx.Request)
   375  					err = ex.New(errPostRender, ex.OptInner(err))
   376  				}
   377  			}
   378  		}
   379  	}
   380  }
   381  
   382  // RenderActionBare is the translation step from Action to Handler that omits logging.
   383  func (a *App) RenderActionBare(action Action) Handler {
   384  	return func(w http.ResponseWriter, r *http.Request, route *Route, p RouteParameters) {
   385  		ctx := NewCtx(webutil.NewStatusResponseWriter(w), r, a.ctxOptions(r.Context(), route, p)...)
   386  		defer ctx.Close()
   387  
   388  		for key, value := range a.BaseHeaders {
   389  			ctx.Response.Header()[key] = value
   390  		}
   391  
   392  		if result := action(ctx); result != nil {
   393  			if typed, ok := result.(ResultPreRender); ok {
   394  				if errPreRender := typed.PreRender(ctx); errPreRender != nil {
   395  					a.maybeLogFatal(ctx.Context(), errPreRender, ctx.Request)
   396  				}
   397  			}
   398  			if err := result.Render(ctx); err != nil {
   399  				a.maybeLogFatal(ctx.Context(), err, ctx.Request)
   400  			}
   401  			if typed, ok := result.(ResultPostRender); ok {
   402  				if errPostRender := typed.PostRender(ctx); errPostRender != nil {
   403  					a.maybeLogFatal(ctx.Context(), errPostRender, ctx.Request)
   404  				}
   405  			}
   406  		}
   407  	}
   408  }
   409  
   410  //
   411  // startup helpers
   412  //
   413  
   414  // StartupTasks runs common startup tasks.
   415  // These tasks include anything outside setting up the underlying server itself.
   416  // Right now, this is limited to initializing the view cache if relevant.
   417  func (a *App) StartupTasks() (err error) {
   418  	if err = a.Views.Initialize(); err != nil {
   419  		return
   420  	}
   421  	return nil
   422  }
   423  
   424  //
   425  // internal helpers
   426  //
   427  
   428  func (a *App) formatStaticMountRoute(route string) string {
   429  	mountedRoute := route
   430  	if !strings.HasSuffix(mountedRoute, "*"+RouteTokenFilepath) {
   431  		if strings.HasSuffix(mountedRoute, "/") {
   432  			mountedRoute = mountedRoute + "*" + RouteTokenFilepath
   433  		} else {
   434  			mountedRoute = mountedRoute + "/*" + RouteTokenFilepath
   435  		}
   436  	}
   437  	return mountedRoute
   438  }
   439  
   440  func (a *App) httpServerOptions() []webutil.HTTPServerOption {
   441  	return []webutil.HTTPServerOption{
   442  		webutil.OptHTTPServerHandler(a),
   443  		webutil.OptHTTPServerTLSConfig(a.TLSConfig),
   444  		webutil.OptHTTPServerAddr(a.Config.BindAddrOrDefault()),
   445  		webutil.OptHTTPServerMaxHeaderBytes(a.Config.MaxHeaderBytesOrDefault()),
   446  		webutil.OptHTTPServerReadTimeout(a.Config.ReadTimeoutOrDefault()),
   447  		webutil.OptHTTPServerReadHeaderTimeout(a.Config.ReadHeaderTimeoutOrDefault()),
   448  		webutil.OptHTTPServerWriteTimeout(a.Config.WriteTimeoutOrDefault()),
   449  		webutil.OptHTTPServerIdleTimeout(a.Config.IdleTimeoutOrDefault()),
   450  		webutil.OptHTTPServerBaseContext(a.BaseContext),
   451  	}
   452  }
   453  
   454  func (a *App) ctxOptions(ctx context.Context, route *Route, p RouteParameters) []CtxOption {
   455  	return []CtxOption{
   456  		OptCtxApp(a),
   457  		OptCtxAuth(a.Auth),
   458  		OptCtxDefaultProvider(a.DefaultProvider),
   459  		OptCtxViews(a.Views),
   460  		OptCtxRoute(route),
   461  		OptCtxRouteParams(p),
   462  		OptCtxState(a.BaseState.Copy()),
   463  		OptCtxLog(a.Log),
   464  		OptCtxTracer(a.Tracer),
   465  		OptCtxRequestStarted(GetRequestStarted(ctx)),
   466  	}
   467  }
   468  
   469  func (a *App) recover(w http.ResponseWriter, req *http.Request) {
   470  	if rcv := recover(); rcv != nil {
   471  		err := ex.New(rcv)
   472  		a.maybeLogFatal(req.Context(), err, req)
   473  		if a.PanicAction != nil {
   474  			a.RenderAction(func(ctx *Ctx) Result {
   475  				return a.PanicAction(ctx, err)
   476  			})(w, req, nil, nil)
   477  			return
   478  		}
   479  		http.Error(w, "an internal server error occurred", http.StatusInternalServerError)
   480  		return
   481  	}
   482  }
   483  
   484  func (a *App) maybeLogFatal(ctx context.Context, err error, req *http.Request) {
   485  	if !logger.IsLoggerSet(a.Log) || err == nil {
   486  		return
   487  	}
   488  	a.Log.TriggerContext(
   489  		ctx,
   490  		logger.NewErrorEvent(
   491  			logger.Fatal,
   492  			err,
   493  			logger.OptErrorEventState(req),
   494  		),
   495  	)
   496  }
   497  
   498  func (a *App) logRequest(r *Ctx) {
   499  	requestEvent := webutil.NewHTTPRequestEvent(r.Request.Clone(r.Context()),
   500  		webutil.OptHTTPRequestStatusCode(r.Response.StatusCode()),
   501  		webutil.OptHTTPRequestContentLength(r.Response.ContentLength()),
   502  		webutil.OptHTTPRequestHeader(r.Response.Header().Clone()),
   503  		webutil.OptHTTPRequestElapsed(r.Elapsed()),
   504  	)
   505  	if r.Route != nil {
   506  		requestEvent.Route = r.Route.String()
   507  	}
   508  	if requestEvent.Header != nil {
   509  		requestEvent.ContentType = requestEvent.Header.Get(webutil.HeaderContentType)
   510  		requestEvent.ContentEncoding = requestEvent.Header.Get(webutil.HeaderContentEncoding)
   511  	}
   512  	a.maybeLogTrigger(r.Context(), r.Log, requestEvent)
   513  }
   514  
   515  func (a *App) maybeLogTrigger(ctx context.Context, log logger.Log, e logger.Event) {
   516  	if !logger.IsLoggerSet(log) || e == nil {
   517  		return
   518  	}
   519  	log.TriggerContext(ctx, e)
   520  }