github.com/System-Glitch/goyave/v3@v3.6.1-0.20210226143142-ac2fe42ee80e/goyave.go (about)

     1  package goyave
     2  
     3  import (
     4  	"context"
     5  	"log"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"os/signal"
    10  	"strconv"
    11  	"sync"
    12  	"syscall"
    13  	"time"
    14  
    15  	"github.com/System-Glitch/goyave/v3/config"
    16  	"github.com/System-Glitch/goyave/v3/database"
    17  	"github.com/System-Glitch/goyave/v3/lang"
    18  )
    19  
    20  var (
    21  	server             *http.Server
    22  	redirectServer     *http.Server
    23  	router             *Router
    24  	maintenanceHandler http.Handler
    25  	sigChannel         chan os.Signal
    26  	tlsStopChannel     chan struct{} = make(chan struct{}, 1)
    27  	stopChannel        chan struct{} = make(chan struct{}, 1)
    28  	hookChannel        chan struct{} = make(chan struct{}, 1)
    29  
    30  	// Critical config entries (cached for better performance)
    31  	protocol        string
    32  	maxPayloadSize  int64
    33  	defaultLanguage string
    34  
    35  	startupHooks       []func()
    36  	shutdownHooks      []func()
    37  	ready              bool = false
    38  	maintenanceEnabled bool = false
    39  	mutex                   = &sync.RWMutex{}
    40  	once               sync.Once
    41  
    42  	// Logger the logger for default output
    43  	// Writes to stdout by default.
    44  	Logger *log.Logger = log.New(os.Stdout, "", log.LstdFlags)
    45  
    46  	// AccessLogger the logger for access. This logger
    47  	// is used by the logging middleware.
    48  	// Writes to stdout by default.
    49  	AccessLogger *log.Logger = log.New(os.Stdout, "", 0)
    50  
    51  	// ErrLogger the logger in which errors and stacktraces are written.
    52  	// Writes to stderr by default.
    53  	ErrLogger *log.Logger = log.New(os.Stderr, "", log.LstdFlags)
    54  )
    55  
    56  const (
    57  	// ExitInvalidConfig the exit code returned when the config
    58  	// validation doesn't pass.
    59  	ExitInvalidConfig = 3
    60  
    61  	// ExitNetworkError the exit code returned when an error
    62  	// occurs when opening the network listener
    63  	ExitNetworkError = 4
    64  
    65  	// ExitHTTPError the exit code returned when an error
    66  	// occurs in the HTTP server (port already in use for example)
    67  	ExitHTTPError = 5
    68  )
    69  
    70  // Error wrapper for errors directely related to the server itself.
    71  // Contains an exit code and the original error.
    72  type Error struct {
    73  	ExitCode int
    74  	Err      error
    75  }
    76  
    77  func (e *Error) Error() string {
    78  	return e.Err.Error()
    79  }
    80  
    81  // IsReady returns true if the server has finished initializing and
    82  // is ready to serve incoming requests.
    83  func IsReady() bool {
    84  	mutex.RLock()
    85  	defer mutex.RUnlock()
    86  	return ready
    87  }
    88  
    89  // RegisterStartupHook to execute some code once the server is ready and running.
    90  func RegisterStartupHook(hook func()) {
    91  	mutex.Lock()
    92  	startupHooks = append(startupHooks, hook)
    93  	mutex.Unlock()
    94  }
    95  
    96  // ClearStartupHooks removes all startup hooks.
    97  func ClearStartupHooks() {
    98  	mutex.Lock()
    99  	startupHooks = []func(){}
   100  	mutex.Unlock()
   101  }
   102  
   103  // RegisterShutdownHook to execute some code after the server stopped.
   104  // Shutdown hooks are executed before goyave.Start() returns.
   105  func RegisterShutdownHook(hook func()) {
   106  	mutex.Lock()
   107  	shutdownHooks = append(shutdownHooks, hook)
   108  	mutex.Unlock()
   109  }
   110  
   111  // ClearShutdownHooks removes all shutdown hooks.
   112  func ClearShutdownHooks() {
   113  	mutex.Lock()
   114  	shutdownHooks = []func(){}
   115  	mutex.Unlock()
   116  }
   117  
   118  // Start starts the web server.
   119  // The routeRegistrer parameter is a function aimed at registering all your routes and middleware.
   120  //  import (
   121  //      "github.com/System-Glitch/goyave/v3"
   122  //      "github.com/username/projectname/route"
   123  //  )
   124  //
   125  //  func main() {
   126  //      if err := goyave.Start(route.Register); err != nil {
   127  //          os.Exit(err.(*goyave.Error).ExitCode)
   128  //      }
   129  //  }
   130  //
   131  // Errors returned can be safely type-asserted to "*goyave.Error".
   132  // Panics if the server is already running.
   133  func Start(routeRegistrer func(*Router)) error {
   134  	if IsReady() {
   135  		ErrLogger.Panicf("Server is already running.")
   136  	}
   137  
   138  	mutex.Lock()
   139  	if !config.IsLoaded() {
   140  		if err := config.Load(); err != nil {
   141  			ErrLogger.Println(err)
   142  			mutex.Unlock()
   143  			return &Error{ExitInvalidConfig, err}
   144  		}
   145  	}
   146  
   147  	// Performance improvements by loading critical config entries beforehand
   148  	cacheCriticalConfig()
   149  
   150  	lang.LoadDefault()
   151  	lang.LoadAllAvailableLanguages()
   152  
   153  	if config.GetBool("database.autoMigrate") && config.GetString("database.connection") != "none" {
   154  		database.Migrate()
   155  	}
   156  
   157  	router = newRouter()
   158  	routeRegistrer(router)
   159  	regexCache = nil // Clear regex cache
   160  	return startServer(router)
   161  }
   162  
   163  func cacheCriticalConfig() {
   164  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   165  	defaultLanguage = config.GetString("app.defaultLanguage")
   166  	protocol = config.GetString("server.protocol")
   167  }
   168  
   169  // EnableMaintenance replace the main server handler with the "Service Unavailable" handler.
   170  func EnableMaintenance() {
   171  	mutex.Lock()
   172  	server.Handler = getMaintenanceHandler()
   173  	maintenanceEnabled = true
   174  	mutex.Unlock()
   175  }
   176  
   177  // DisableMaintenance replace the main server handler with the original router.
   178  func DisableMaintenance() {
   179  	mutex.Lock()
   180  	server.Handler = router
   181  	maintenanceEnabled = false
   182  	mutex.Unlock()
   183  }
   184  
   185  // IsMaintenanceEnabled return true if the server is currently in maintenance mode.
   186  func IsMaintenanceEnabled() bool {
   187  	mutex.RLock()
   188  	defer mutex.RUnlock()
   189  	return maintenanceEnabled
   190  }
   191  
   192  // GetRoute get a named route.
   193  // Returns nil if the route doesn't exist.
   194  func GetRoute(name string) *Route {
   195  	mutex.Lock()
   196  	defer mutex.Unlock()
   197  	return router.namedRoutes[name]
   198  }
   199  
   200  func getMaintenanceHandler() http.Handler {
   201  	once.Do(func() {
   202  		maintenanceHandler = http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) {
   203  			resp.WriteHeader(http.StatusServiceUnavailable)
   204  		})
   205  	})
   206  	return maintenanceHandler
   207  }
   208  
// Stop gracefully shuts down the server without interrupting any
// active connections.
//
// Make sure the program doesn't exit and waits instead for Stop to return.
//
// Stop does not attempt to close nor wait for hijacked
// connections such as WebSockets. The caller of Stop should
// separately notify such long-lived connections of shutdown and wait
// for them to close, if desired.
//
// Stop holds the package mutex for its whole duration. That includes the
// time stop() spends shutting the servers down (bounded by a 5 second
// context) and running the shutdown hooks. Stop does nothing if no server
// is running and no signal watcher is registered. The error returned by
// stop() is discarded.
func Stop() {
	mutex.Lock()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	stop(ctx)
	if sigChannel != nil {
		// The signal watcher goroutine started by registerShutdownHook is
		// still waiting. Send it a token so it exits, then wait for the
		// token it sends back as an acknowledgement.
		// NOTE(review): hookChannel is buffered (capacity 1). If the watcher
		// is not yet parked in its select, this goroutine could receive its
		// own token back and leave the watcher running — confirm.
		hookChannel <- struct{}{} // Clear shutdown hook
		<-hookChannel
		sigChannel = nil
	}
	mutex.Unlock()
}
   230  
// stop shuts down the running server, if any, and resets the package state.
// Every caller in this file holds mutex while calling it: Stop, the
// listener error path of startServer, and the signal watcher of
// registerShutdownHook.
//
// Steps, in order:
//   - shut the main server down with ctx;
//   - close the database;
//   - clear server, router, ready and maintenanceEnabled;
//   - shut the TLS redirect server down and wait for its goroutine to
//     signal tlsStopChannel;
//   - run the shutdown hooks;
//   - signal stopChannel so that startServer can return.
//
// Only the main server's Shutdown error is returned. When no server is set,
// stop does nothing and returns nil.
func stop(ctx context.Context) error {
	var err error
	if server != nil {
		err = server.Shutdown(ctx)
		database.Close()
		server = nil
		router = nil
		ready = false
		maintenanceEnabled = false
		if redirectServer != nil {
			// NOTE(review): the redirect server's Shutdown error is ignored.
			redirectServer.Shutdown(ctx)
			// Wait for the serving goroutine of startTLSRedirectServer to
			// close its listener and report that it has finished.
			<-tlsStopChannel
			redirectServer = nil
		}

		for _, hook := range shutdownHooks {
			hook()
		}
		// Releases the deferred receive in startServer.
		stopChannel <- struct{}{}
	}
	return err
}
   253  
   254  func getHost(protocol string) string {
   255  	var port string
   256  	if protocol == "https" {
   257  		port = "server.httpsPort"
   258  	} else {
   259  		port = "server.port"
   260  	}
   261  	return config.GetString("server.host") + ":" + strconv.Itoa(config.GetInt(port))
   262  }
   263  
   264  func getAddress(protocol string) string {
   265  	var shouldShowPort bool
   266  	var port string
   267  	if protocol == "https" {
   268  		p := config.GetInt("server.httpsPort")
   269  		port = strconv.Itoa(p)
   270  		shouldShowPort = p != 443
   271  	} else {
   272  		p := config.GetInt("server.port")
   273  		port = strconv.Itoa(p)
   274  		shouldShowPort = p != 80
   275  	}
   276  	host := config.GetString("server.domain")
   277  	if len(host) == 0 {
   278  		host = config.GetString("server.host")
   279  	}
   280  
   281  	if shouldShowPort {
   282  		host += ":" + port
   283  	}
   284  
   285  	return protocol + "://" + host
   286  }
   287  
   288  // BaseURL returns the base URL of your application.
   289  func BaseURL() string {
   290  	return getAddress(config.GetString("server.protocol"))
   291  }
   292  
// startTLSRedirectServer starts a plain HTTP server on the "server.port"
// address. It answers every request with a 308 Permanent Redirect to the
// same path (and query, if any) on the HTTPS address.
//
// It is called by startServer while mutex is held and after ready has been
// set. If the listener cannot be opened, the error is logged, redirectServer
// is reset to nil and the main server starts anyway.
//
// Otherwise the redirect server runs in a goroutine. Once Serve returns
// http.ErrServerClosed, that goroutine closes the listener and signals
// tlsStopChannel, which stop() waits for.
func startTLSRedirectServer() {
	httpsAddress := getAddress("https")
	timeout := time.Duration(config.GetInt("server.timeout")) * time.Second
	redirectServer = &http.Server{
		Addr:         getHost("http"),
		WriteTimeout: timeout,
		ReadTimeout:  timeout,
		IdleTimeout:  timeout * 2,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			address := httpsAddress + r.URL.Path
			query := r.URL.Query()
			if len(query) != 0 {
				// NOTE(review): Encode re-encodes the parameters and sorts them
				// by key. Forwarding r.URL.RawQuery would keep the query as sent.
				address += "?" + query.Encode()
			}
			http.Redirect(w, r, address, http.StatusPermanentRedirect)
		}),
	}

	ln, err := net.Listen("tcp", redirectServer.Addr)
	if err != nil {
		ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error())
		redirectServer = nil
		return
	}

	// Snapshot the state while the caller still holds mutex, so the
	// goroutine below does not read the package variables unguarded.
	ok := ready
	r := redirectServer

	go func() {
		if ok && r != nil {
			if err := r.Serve(ln); err != nil && err != http.ErrServerClosed {
				ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error())
				// Unexpected failure: tlsStopChannel is NOT signaled here.
				// Clearing redirectServer is what keeps stop() from waiting
				// on it.
				// NOTE(review): if stop() already holds mutex and is blocked on
				// tlsStopChannel, this Lock can never succeed. Possible
				// deadlock — confirm.
				mutex.Lock()
				redirectServer = nil
				ln.Close()
				mutex.Unlock()
				return
			}
		}
		ln.Close()
		tlsStopChannel <- struct{}{}
	}()
}
   336  
// startServer builds the main HTTP(S) server for router, opens its listener
// and serves until the server is shut down.
//
// It must be called with mutex held (Start does so). The lock is released
// right before the startup hooks run and serving begins, or before
// returning on a listener error. The deferred receive on stopChannel makes
// startServer wait for stop() to complete (including the shutdown hooks)
// before it returns.
//
// Return values:
//   - an *Error with ExitNetworkError if the listener cannot be opened;
//   - an *Error with ExitHTTPError if Serve/ServeTLS fails with anything
//     other than http.ErrServerClosed;
//   - nil after a normal shutdown.
func startServer(router *Router) error {
	defer func() {
		<-stopChannel // Wait for stop() to finish before returning
	}()
	timeout := time.Duration(config.GetInt("server.timeout")) * time.Second
	server = &http.Server{
		Addr:         getHost(protocol),
		WriteTimeout: timeout,
		ReadTimeout:  timeout,
		IdleTimeout:  timeout * 2,
		Handler:      router,
	}

	// Start directly in maintenance mode when the config asks for it.
	if config.GetBool("server.maintenance") {
		server.Handler = getMaintenanceHandler()
		maintenanceEnabled = true
	}

	ln, err := net.Listen("tcp", server.Addr)
	if err != nil {
		ErrLogger.Println(err)
		// stop() resets the state and signals stopChannel, which the
		// deferred receive above consumes.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		stop(ctx)
		mutex.Unlock()
		return &Error{ExitNetworkError, err}
	}
	defer ln.Close()
	registerShutdownHook(stop)
	<-hookChannel // Wait until the signal watcher goroutine is running.

	ready = true
	if protocol == "https" {
		startTLSRedirectServer()

		// Keep a local reference: once mutex is released, stop() may set
		// the package-level server to nil.
		s := server
		mutex.Unlock()
		runStartupHooks()
		if err := s.ServeTLS(ln, config.GetString("server.tls.cert"), config.GetString("server.tls.key")); err != nil && err != http.ErrServerClosed {
			ErrLogger.Println(err)
			// Stop runs stop(), which signals stopChannel for the deferred receive.
			Stop()
			return &Error{ExitHTTPError, err}
		}
	} else {

		// Keep a local reference: once mutex is released, stop() may set
		// the package-level server to nil.
		s := server
		mutex.Unlock()
		runStartupHooks()
		if err := s.Serve(ln); err != nil && err != http.ErrServerClosed {
			ErrLogger.Println(err)
			// Stop runs stop(), which signals stopChannel for the deferred receive.
			Stop()
			return &Error{ExitHTTPError, err}
		}
	}

	return nil
}
   394  
   395  func runStartupHooks() {
   396  	for _, hook := range startupHooks {
   397  		go hook()
   398  	}
   399  }
   400  
// registerShutdownHook subscribes sigChannel to SIGINT and SIGTERM and starts
// a watcher goroutine that calls hook when one of those signals arrives.
// It is called by startServer while mutex is held.
//
// The watcher talks to the rest of the package through hookChannel:
//   - it first sends a token, which startServer receives to know that the
//     watcher is running;
//   - if Stop later sends a token before any signal arrives, the watcher
//     sends one back as an acknowledgement and exits without calling hook.
//
// On a signal, the watcher takes mutex, clears sigChannel and calls hook
// with a context that times out after 5 seconds. The error returned by hook
// is discarded.
//
// NOTE(review): signal.Stop is never called on sigChannel. After the watcher
// exits, SIGINT/SIGTERM keep being routed to this now-unread channel instead
// of terminating the process — confirm this is intended.
func registerShutdownHook(hook func(context.Context) error) {
	sigChannel = make(chan os.Signal, 1)
	signal.Notify(sigChannel, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		hookChannel <- struct{}{}
		select {
		case <-hookChannel:
			hookChannel <- struct{}{}
		case <-sigChannel: // Block until SIGINT or SIGTERM received
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			mutex.Lock()
			sigChannel = nil
			hook(ctx)
			mutex.Unlock()
		}
	}()
}
   421  
// TODO refactor server startup (use context)