github.com/aerogo/aero@v1.0.0/Application.go (about)

     1  package aero
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"mime"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"os/signal"
    11  	"sort"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"syscall"
    16  	"time"
    17  
    18  	"github.com/aerogo/csp"
    19  
    20  	"github.com/aerogo/http/client"
    21  	performance "github.com/aerogo/linter-performance"
    22  	"github.com/aerogo/session"
    23  	memstore "github.com/aerogo/session-store-memory"
    24  	"github.com/fatih/color"
    25  	"github.com/julienschmidt/httprouter"
    26  )
    27  
    28  // Application represents a single web service.
    29  type Application struct {
    30  	Config                *Configuration
    31  	Sessions              session.Manager
    32  	Security              ApplicationSecurity
    33  	Linters               []Linter
    34  	Router                *httprouter.Router
    35  	ContentSecurityPolicy *csp.ContentSecurityPolicy
    36  
    37  	servers        [2]*http.Server
    38  	serversMutex   sync.Mutex
    39  	routeTests     map[string][]string
    40  	start          time.Time
    41  	rewrite        func(*RewriteContext)
    42  	middleware     []Middleware
    43  	pushConditions []func(*Context) bool
    44  	onStart        []func()
    45  	onShutdown     []func()
    46  	onPush         []func(*Context)
    47  	stop           chan os.Signal
    48  
    49  	routes struct {
    50  		GET  []string
    51  		POST []string
    52  	}
    53  }
    54  
    55  // New creates a new application.
    56  func New() *Application {
    57  	app := new(Application)
    58  	app.start = time.Now()
    59  	app.routeTests = make(map[string][]string)
    60  	app.Router = httprouter.New()
    61  
    62  	// Default linters
    63  	app.Linters = []Linter{
    64  		performance.New(),
    65  	}
    66  
    67  	// Default CSP
    68  	app.ContentSecurityPolicy = csp.New()
    69  	app.ContentSecurityPolicy.SetMap(csp.Map{
    70  		"default-src":  "'none'",
    71  		"img-src":      "https:",
    72  		"media-src":    "https:",
    73  		"script-src":   "'self'",
    74  		"style-src":    "'self'",
    75  		"font-src":     "https:",
    76  		"manifest-src": "'self'",
    77  		"connect-src":  "https: wss:",
    78  		"worker-src":   "'self'",
    79  		"frame-src":    "https:",
    80  		"base-uri":     "'self'",
    81  		"form-action":  "'self'",
    82  	})
    83  
    84  	// Configuration
    85  	app.Config = new(Configuration)
    86  	app.Config.Reset()
    87  	app.Load()
    88  
    89  	// Default session store: Memory
    90  	app.Sessions.Store = memstore.New()
    91  
    92  	// Default style
    93  	// app.SetStyle("")
    94  
    95  	// Set mime type for WebP because Go standard library doesn't include it
    96  	mime.AddExtensionType(".webp", "image/webp")
    97  
    98  	// Receive signals
    99  	app.stop = make(chan os.Signal, 1)
   100  	signal.Notify(app.stop, os.Interrupt, syscall.SIGTERM)
   101  
   102  	return app
   103  }
   104  
   105  // Get registers your function to be called when a certain GET path has been requested.
   106  func (app *Application) Get(path string, handle Handle) {
   107  	app.routes.GET = append(app.routes.GET, path)
   108  	app.Router.GET(path, app.createRouteHandler(path, handle))
   109  }
   110  
   111  // Post registers your function to be called when a certain POST path has been requested.
   112  func (app *Application) Post(path string, handle Handle) {
   113  	app.routes.POST = append(app.routes.POST, path)
   114  	app.Router.POST(path, app.createRouteHandler(path, handle))
   115  }
   116  
   117  // createRouteHandler creates a handler function for httprouter.
   118  func (app *Application) createRouteHandler(path string, handle Handle) httprouter.Handle {
   119  	return func(response http.ResponseWriter, request *http.Request, params httprouter.Params) {
   120  		// Create context.
   121  		ctx := Context{
   122  			App:        app,
   123  			StatusCode: http.StatusOK,
   124  			request:    request,
   125  			response:   response,
   126  			params:     params,
   127  		}
   128  
   129  		// The last part of the call chain will send the actual response.
   130  		lastPartOfCallChain := func() {
   131  			data := handle(&ctx)
   132  			ctx.respond(data)
   133  		}
   134  
   135  		// Declare the type of generateNext so that we can define it recursively in the next part.
   136  		var generateNext func(index int) func()
   137  
   138  		// Create a function that returns a bound function next()
   139  		// which can be used as the 2nd parameter in the call chain.
   140  		generateNext = func(index int) func() {
   141  			if index == len(app.middleware) {
   142  				return lastPartOfCallChain
   143  			}
   144  
   145  			return func() {
   146  				app.middleware[index](&ctx, generateNext(index+1))
   147  			}
   148  		}
   149  
   150  		generateNext(0)()
   151  	}
   152  }
   153  
   154  // Run starts your application.
   155  func (app *Application) Run() {
   156  	app.ListenAndServe()
   157  
   158  	for _, callback := range app.onStart {
   159  		callback()
   160  	}
   161  
   162  	app.TestManifest()
   163  	app.TestRoutes()
   164  	app.Wait()
   165  	app.Shutdown()
   166  }
   167  
   168  // Use adds middleware to your middleware chain.
   169  func (app *Application) Use(middlewares ...Middleware) {
   170  	app.middleware = append(app.middleware, middlewares...)
   171  }
   172  
   173  // Load loads the application configuration from config.json.
   174  func (app *Application) Load() {
   175  	config, err := LoadConfig("config.json")
   176  
   177  	if err != nil {
   178  		// Ignore missing config file, we can perfectly run without one
   179  		return
   180  	}
   181  
   182  	app.Config = config
   183  }
   184  
   185  // ListenAndServe starts the server.
   186  // It guarantees that a TCP listener is listening on the ports defined in the config
   187  // when the function returns.
   188  func (app *Application) ListenAndServe() {
   189  	if app.Security.Key != "" && app.Security.Certificate != "" {
   190  		listener := app.listen(":" + strconv.Itoa(app.Config.Ports.HTTPS))
   191  
   192  		go func() {
   193  			app.serveHTTPS(listener)
   194  		}()
   195  
   196  		fmt.Println("Server running on:", color.GreenString("https://localhost:"+strconv.Itoa(app.Config.Ports.HTTPS)))
   197  	} else {
   198  		fmt.Println("Server running on:", color.GreenString("http://localhost:"+strconv.Itoa(app.Config.Ports.HTTP)))
   199  	}
   200  
   201  	listener := app.listen(":" + strconv.Itoa(app.Config.Ports.HTTP))
   202  
   203  	go func() {
   204  		app.serveHTTP(listener)
   205  	}()
   206  }
   207  
   208  // Wait will make the process wait until it is killed.
   209  func (app *Application) Wait() {
   210  	<-app.stop
   211  }
   212  
   213  // Shutdown will gracefully shut down all servers.
   214  func (app *Application) Shutdown() {
   215  	app.serversMutex.Lock()
   216  	defer app.serversMutex.Unlock()
   217  
   218  	shutdown(app.servers[0])
   219  	shutdown(app.servers[1])
   220  
   221  	for _, callback := range app.onShutdown {
   222  		callback()
   223  	}
   224  }
   225  
   226  // shutdown will gracefully shut down the server.
   227  func shutdown(server *http.Server) {
   228  	if server == nil {
   229  		return
   230  	}
   231  
   232  	// Add a timeout to the server shutdown
   233  	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
   234  	defer cancel()
   235  
   236  	// Shut down server
   237  	err := server.Shutdown(ctx)
   238  
   239  	if err != nil {
   240  		fmt.Println(err)
   241  	}
   242  }
   243  
   244  // OnStart registers a callback to be executed on server start.
   245  func (app *Application) OnStart(callback func()) {
   246  	app.onStart = append(app.onStart, callback)
   247  }
   248  
   249  // OnEnd registers a callback to be executed on server shutdown.
   250  func (app *Application) OnEnd(callback func()) {
   251  	app.onShutdown = append(app.onShutdown, callback)
   252  }
   253  
   254  // OnPush registers a callback to be executed when an HTTP/2 push happens.
   255  func (app *Application) OnPush(callback func(*Context)) {
   256  	app.onPush = append(app.onPush, callback)
   257  }
   258  
   259  // AddPushCondition registers a callback to be executed when an HTTP/2 push happens.
   260  func (app *Application) AddPushCondition(test func(*Context) bool) {
   261  	app.pushConditions = append(app.pushConditions, test)
   262  }
   263  
   264  // Rewrite sets the URL rewrite function.
   265  func (app *Application) Rewrite(rewrite func(*RewriteContext)) {
   266  	app.rewrite = rewrite
   267  }
   268  
   269  // StartTime returns the time the application started.
   270  func (app *Application) StartTime() time.Time {
   271  	return app.start
   272  }
   273  
   274  // Handler returns the request handler used by the application.
   275  func (app *Application) Handler() http.Handler {
   276  	router := app.Router
   277  	rewrite := app.rewrite
   278  
   279  	if rewrite != nil {
   280  		return &rewriteHandler{
   281  			rewrite: rewrite,
   282  			router:  router,
   283  		}
   284  	}
   285  
   286  	return router
   287  }
   288  
   289  // createServer creates an http server instance.
   290  func (app *Application) createServer() *http.Server {
   291  	return &http.Server{
   292  		Handler:           app.Handler(),
   293  		ReadHeaderTimeout: 5 * time.Second,
   294  		WriteTimeout:      180 * time.Second,
   295  		IdleTimeout:       120 * time.Second,
   296  		TLSConfig:         createTLSConfig(),
   297  	}
   298  }
   299  
   300  // listen returns a Listener for the given address.
   301  func (app *Application) listen(address string) Listener {
   302  	listener, err := net.Listen("tcp", address)
   303  
   304  	if err != nil {
   305  		panic(err)
   306  	}
   307  
   308  	return Listener{listener.(*net.TCPListener)}
   309  }
   310  
   311  // serveHTTP serves requests from the given listener.
   312  func (app *Application) serveHTTP(listener Listener) {
   313  	server := app.createServer()
   314  
   315  	app.serversMutex.Lock()
   316  	app.servers[0] = server
   317  	app.serversMutex.Unlock()
   318  
   319  	// This will block the calling goroutine until the server shuts down.
   320  	err := server.Serve(listener)
   321  
   322  	if err != nil && !strings.Contains(err.Error(), "closed") {
   323  		panic(err)
   324  	}
   325  }
   326  
   327  // serveHTTPS serves requests from the given listener.
   328  func (app *Application) serveHTTPS(listener Listener) {
   329  	server := app.createServer()
   330  
   331  	app.serversMutex.Lock()
   332  	app.servers[1] = server
   333  	app.serversMutex.Unlock()
   334  
   335  	// This will block the calling goroutine until the server shuts down.
   336  	err := server.ServeTLS(listener, app.Security.Certificate, app.Security.Key)
   337  
   338  	if err != nil && !strings.Contains(err.Error(), "closed") {
   339  		panic(err)
   340  	}
   341  }
   342  
   343  // Test tests the given URI paths when the application starts.
   344  func (app *Application) Test(route string, paths []string) {
   345  	app.routeTests[route] = paths
   346  }
   347  
   348  // TestManifest tests your application's manifest.
   349  func (app *Application) TestManifest() {
   350  	manifest := app.Config.Manifest
   351  
   352  	// Warn about short name length (Google Lighthouse)
   353  	// https://developer.chrome.com/apps/manifest/name#short_name
   354  	if len(manifest.ShortName) >= 12 {
   355  		color.Yellow("The short name of your application should have less than 12 characters")
   356  	}
   357  }
   358  
   359  // TestRoutes tests your application's routes.
   360  func (app *Application) TestRoutes() {
   361  	fmt.Println(strings.Repeat("-", 80))
   362  
   363  	go func() {
   364  		sort.Strings(app.routes.GET)
   365  
   366  		for _, route := range app.routes.GET {
   367  			// Skip ajax routes
   368  			if strings.HasPrefix(route, "/_") {
   369  				continue
   370  			}
   371  
   372  			// Check if the user defined test routes for the given route
   373  			testRoutes, exists := app.routeTests[route]
   374  
   375  			if exists {
   376  				for _, testRoute := range testRoutes {
   377  					app.TestRoute(route, testRoute)
   378  				}
   379  
   380  				continue
   381  			}
   382  
   383  			// Skip routes with parameters and display a warning to indicate it needs a test route
   384  			if strings.Contains(route, ":") {
   385  				color.Yellow(route)
   386  				continue
   387  			}
   388  
   389  			// Test the static route without parameters
   390  			app.TestRoute(route, route)
   391  		}
   392  
   393  		// json, _ := Post("https://html5.validator.nu/?out=json").Header("Content-Type", "text/html; charset=utf-8").Header("Content-Encoding", "gzip").Body(body).Send()
   394  		// fmt.Println(json)
   395  	}()
   396  }
   397  
   398  // TestRoute tests the given route.
   399  func (app *Application) TestRoute(route string, uri string) {
   400  	for _, linter := range app.Linters {
   401  		linter.Begin(route, uri)
   402  	}
   403  
   404  	response, _ := client.Get("http://localhost:" + strconv.Itoa(app.Config.Ports.HTTP) + uri).End()
   405  
   406  	for _, linter := range app.Linters {
   407  		linter.End(route, uri, response)
   408  	}
   409  }