goyave.dev/goyave/v4@v4.4.11/goyave.go (about)

     1  package goyave
     2  
     3  import (
     4  	"context"
     5  	"log"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"os/signal"
    10  	"strconv"
    11  	"sync"
    12  	"syscall"
    13  	"time"
    14  
    15  	"goyave.dev/goyave/v4/config"
    16  	"goyave.dev/goyave/v4/database"
    17  	"goyave.dev/goyave/v4/lang"
    18  )
    19  
    20  var (
    21  	server             *http.Server
    22  	redirectServer     *http.Server
    23  	router             *Router
    24  	maintenanceHandler http.Handler
    25  	sigChannel         chan os.Signal
    26  	tlsStopChannel     = make(chan struct{}, 1)
    27  	stopChannel        = make(chan struct{}, 1)
    28  	hookChannel        = make(chan struct{}, 1)
    29  
    30  	// Critical config entries (cached for better performance)
    31  	protocol        string
    32  	maxPayloadSize  int64
    33  	defaultLanguage string
    34  
    35  	startupHooks       []func()
    36  	shutdownHooks      []func()
    37  	ready              = false
    38  	maintenanceEnabled = false
    39  	mutex              = &sync.RWMutex{}
    40  	once               sync.Once
    41  
    42  	// Logger the logger for default output
    43  	// Writes to stdout by default.
    44  	Logger = log.New(os.Stdout, "", log.LstdFlags)
    45  
    46  	// AccessLogger the logger for access. This logger
    47  	// is used by the logging middleware.
    48  	// Writes to stdout by default.
    49  	AccessLogger = log.New(os.Stdout, "", 0)
    50  
    51  	// ErrLogger the logger in which errors and stacktraces are written.
    52  	// Writes to stderr by default.
    53  	ErrLogger = log.New(os.Stderr, "", log.LstdFlags)
    54  )
    55  
    56  const (
    57  	// ExitInvalidConfig the exit code returned when the config
    58  	// validation doesn't pass.
    59  	ExitInvalidConfig = 3
    60  
    61  	// ExitNetworkError the exit code returned when an error
    62  	// occurs when opening the network listener
    63  	ExitNetworkError = 4
    64  
    65  	// ExitHTTPError the exit code returned when an error
    66  	// occurs in the HTTP server (port already in use for example)
    67  	ExitHTTPError = 5
    68  )
    69  
    70  // Error wrapper for errors directely related to the server itself.
    71  // Contains an exit code and the original error.
    72  type Error struct {
    73  	Err      error
    74  	ExitCode int
    75  }
    76  
    77  func (e *Error) Error() string {
    78  	return e.Err.Error()
    79  }
    80  
    81  // IsReady returns true if the server has finished initializing and
    82  // is ready to serve incoming requests.
    83  func IsReady() bool {
    84  	mutex.RLock()
    85  	defer mutex.RUnlock()
    86  	return ready
    87  }
    88  
    89  // RegisterStartupHook to execute some code once the server is ready and running.
    90  func RegisterStartupHook(hook func()) {
    91  	mutex.Lock()
    92  	startupHooks = append(startupHooks, hook)
    93  	mutex.Unlock()
    94  }
    95  
    96  // ClearStartupHooks removes all startup hooks.
    97  func ClearStartupHooks() {
    98  	mutex.Lock()
    99  	startupHooks = []func(){}
   100  	mutex.Unlock()
   101  }
   102  
   103  // RegisterShutdownHook to execute some code after the server stopped.
   104  // Shutdown hooks are executed before goyave.Start() returns.
   105  func RegisterShutdownHook(hook func()) {
   106  	mutex.Lock()
   107  	shutdownHooks = append(shutdownHooks, hook)
   108  	mutex.Unlock()
   109  }
   110  
   111  // ClearShutdownHooks removes all shutdown hooks.
   112  func ClearShutdownHooks() {
   113  	mutex.Lock()
   114  	shutdownHooks = []func(){}
   115  	mutex.Unlock()
   116  }
   117  
   118  // Start starts the web server.
   119  // The routeRegistrer parameter is a function aimed at registering all your routes and middleware.
   120  //
   121  //	import (
   122  //	    "goyave.dev/goyave/v4"
   123  //	    "github.com/username/projectname/route"
   124  //	)
   125  //
   126  //	func main() {
   127  //	    if err := goyave.Start(route.Register); err != nil {
   128  //	        os.Exit(err.(*goyave.Error).ExitCode)
   129  //	    }
   130  //	}
   131  //
   132  // Errors returned can be safely type-asserted to "*goyave.Error".
   133  // Panics if the server is already running.
   134  func Start(routeRegistrer func(*Router)) error {
   135  	if IsReady() {
   136  		ErrLogger.Panicf("Server is already running.")
   137  	}
   138  
   139  	mutex.Lock()
   140  	if !config.IsLoaded() {
   141  		if err := config.Load(); err != nil {
   142  			ErrLogger.Println(err)
   143  			mutex.Unlock()
   144  			return &Error{err, ExitInvalidConfig}
   145  		}
   146  	}
   147  
   148  	// Performance improvements by loading critical config entries beforehand
   149  	cacheCriticalConfig()
   150  
   151  	lang.LoadDefault()
   152  	lang.LoadAllAvailableLanguages()
   153  
   154  	if config.GetBool("database.autoMigrate") && config.GetString("database.connection") != "none" {
   155  		database.Migrate()
   156  	}
   157  
   158  	router = NewRouter()
   159  	routeRegistrer(router)
   160  	router.ClearRegexCache()
   161  	return startServer(router)
   162  }
   163  
   164  func cacheCriticalConfig() {
   165  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   166  	defaultLanguage = config.GetString("app.defaultLanguage")
   167  	protocol = config.GetString("server.protocol")
   168  }
   169  
   170  // EnableMaintenance replace the main server handler with the "Service Unavailable" handler.
   171  func EnableMaintenance() {
   172  	mutex.Lock()
   173  	server.Handler = getMaintenanceHandler()
   174  	maintenanceEnabled = true
   175  	mutex.Unlock()
   176  }
   177  
   178  // DisableMaintenance replace the main server handler with the original router.
   179  func DisableMaintenance() {
   180  	mutex.Lock()
   181  	server.Handler = router
   182  	maintenanceEnabled = false
   183  	mutex.Unlock()
   184  }
   185  
   186  // IsMaintenanceEnabled return true if the server is currently in maintenance mode.
   187  func IsMaintenanceEnabled() bool {
   188  	mutex.RLock()
   189  	defer mutex.RUnlock()
   190  	return maintenanceEnabled
   191  }
   192  
   193  // GetRoute get a named route.
   194  // Returns nil if the route doesn't exist.
   195  func GetRoute(name string) *Route {
   196  	mutex.Lock()
   197  	defer mutex.Unlock()
   198  	return router.namedRoutes[name]
   199  }
   200  
   201  func getMaintenanceHandler() http.Handler {
   202  	once.Do(func() {
   203  		maintenanceHandler = http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) {
   204  			resp.WriteHeader(http.StatusServiceUnavailable)
   205  		})
   206  	})
   207  	return maintenanceHandler
   208  }
   209  
   210  // Stop gracefully shuts down the server without interrupting any
   211  // active connections.
   212  //
   213  // Make sure the program doesn't exit and waits instead for Stop to return.
   214  //
   215  // Stop does not attempt to close nor wait for hijacked
   216  // connections such as WebSockets. The caller of Stop should
   217  // separately notify such long-lived connections of shutdown and wait
   218  // for them to close, if desired.
   219  func Stop() {
   220  	mutex.Lock()
   221  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   222  	defer cancel()
   223  	stop(ctx)
   224  	if sigChannel != nil {
   225  		hookChannel <- struct{}{} // Clear shutdown hook
   226  		<-hookChannel
   227  		sigChannel = nil
   228  	}
   229  	mutex.Unlock()
   230  }
   231  
   232  func stop(ctx context.Context) error {
   233  	var err error
   234  	if server != nil {
   235  		err = server.Shutdown(ctx)
   236  		database.Close()
   237  		server = nil
   238  		router = nil
   239  		ready = false
   240  		maintenanceEnabled = false
   241  		if redirectServer != nil {
   242  			redirectServer.Shutdown(ctx)
   243  			<-tlsStopChannel
   244  			redirectServer = nil
   245  		}
   246  
   247  		for _, hook := range shutdownHooks {
   248  			hook()
   249  		}
   250  		stopChannel <- struct{}{}
   251  	}
   252  	return err
   253  }
   254  
   255  func getHost(protocol string) string {
   256  	var port string
   257  	if protocol == "https" {
   258  		port = "server.httpsPort"
   259  	} else {
   260  		port = "server.port"
   261  	}
   262  	return config.GetString("server.host") + ":" + strconv.Itoa(config.GetInt(port))
   263  }
   264  
   265  func getAddress(protocol string) string {
   266  	var shouldShowPort bool
   267  	var port int
   268  	if protocol == "https" {
   269  		port = config.GetInt("server.httpsPort")
   270  		shouldShowPort = port != 443
   271  	} else {
   272  		port = config.GetInt("server.port")
   273  		shouldShowPort = port != 80
   274  	}
   275  	host := config.GetString("server.domain")
   276  	if len(host) == 0 {
   277  		host = config.GetString("server.host")
   278  		if host == "0.0.0.0" {
   279  			host = "127.0.0.1"
   280  		}
   281  	}
   282  
   283  	if shouldShowPort {
   284  		host += ":" + strconv.Itoa(port)
   285  	}
   286  
   287  	return protocol + "://" + host
   288  }
   289  
   290  // BaseURL returns the base URL of your application.
   291  func BaseURL() string {
   292  	if protocol == "" {
   293  		protocol = config.GetString("server.protocol")
   294  	}
   295  	return getAddress(protocol)
   296  }
   297  
   298  // ProxyBaseURL returns the base URL of your application based on the "server.proxy" configuration.
   299  // This is useful when you want to generate an URL when your application is served behind a reverse proxy.
   300  // If "server.proxy.host" configuration is not set, returns the same value as "BaseURL()".
   301  func ProxyBaseURL() string {
   302  	if !config.Has("server.proxy.host") {
   303  		return BaseURL()
   304  	}
   305  
   306  	var shouldShowPort bool
   307  	proto := config.GetString("server.proxy.protocol")
   308  	port := config.GetInt("server.proxy.port")
   309  	if proto == "https" {
   310  		shouldShowPort = port != 443
   311  	} else {
   312  		shouldShowPort = port != 80
   313  	}
   314  	host := config.GetString("server.proxy.host")
   315  	if shouldShowPort {
   316  		host += ":" + strconv.Itoa(port)
   317  	}
   318  
   319  	return proto + "://" + host + config.GetString("server.proxy.base")
   320  }
   321  
   322  func startTLSRedirectServer() {
   323  	httpsAddress := getAddress("https")
   324  	timeout := time.Duration(config.GetInt("server.timeout")) * time.Second
   325  	redirectServer = &http.Server{
   326  		Addr:         getHost("http"),
   327  		WriteTimeout: timeout,
   328  		ReadTimeout:  timeout,
   329  		IdleTimeout:  timeout * 2,
   330  		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   331  			address := httpsAddress + r.URL.Path
   332  			query := r.URL.Query()
   333  			if len(query) != 0 {
   334  				address += "?" + query.Encode()
   335  			}
   336  			http.Redirect(w, r, address, http.StatusPermanentRedirect)
   337  		}),
   338  	}
   339  
   340  	ln, err := net.Listen("tcp", redirectServer.Addr)
   341  	if err != nil {
   342  		ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error())
   343  		redirectServer = nil
   344  		return
   345  	}
   346  
   347  	ok := ready
   348  	r := redirectServer
   349  
   350  	go func() {
   351  		if ok && r != nil {
   352  			if err := r.Serve(ln); err != nil && err != http.ErrServerClosed {
   353  				ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error())
   354  				mutex.Lock()
   355  				redirectServer = nil
   356  				ln.Close()
   357  				mutex.Unlock()
   358  				return
   359  			}
   360  		}
   361  		ln.Close()
   362  		tlsStopChannel <- struct{}{}
   363  	}()
   364  }
   365  
   366  func startServer(router *Router) error {
   367  	defer func() {
   368  		<-stopChannel // Wait for stop() to finish before returning
   369  	}()
   370  	timeout := time.Duration(config.GetInt("server.timeout")) * time.Second
   371  	server = &http.Server{
   372  		Addr:         getHost(protocol),
   373  		WriteTimeout: timeout,
   374  		ReadTimeout:  timeout,
   375  		IdleTimeout:  timeout * 2,
   376  		Handler:      router,
   377  	}
   378  
   379  	if config.GetBool("server.maintenance") {
   380  		server.Handler = getMaintenanceHandler()
   381  		maintenanceEnabled = true
   382  	}
   383  
   384  	ln, err := net.Listen("tcp", server.Addr)
   385  	if err != nil {
   386  		ErrLogger.Println(err)
   387  		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   388  		defer cancel()
   389  		stop(ctx)
   390  		mutex.Unlock()
   391  		return &Error{err, ExitNetworkError}
   392  	}
   393  	defer ln.Close()
   394  
   395  	readyChan := make(chan struct{})
   396  	registerShutdownHook(readyChan, stop)
   397  	<-readyChan
   398  	close(readyChan)
   399  
   400  	ready = true
   401  	if protocol == "https" {
   402  		startTLSRedirectServer()
   403  
   404  		s := server
   405  		mutex.Unlock()
   406  		runStartupHooks()
   407  		if err := s.ServeTLS(ln, config.GetString("server.tls.cert"), config.GetString("server.tls.key")); err != nil && err != http.ErrServerClosed {
   408  			ErrLogger.Println(err)
   409  			Stop()
   410  			return &Error{err, ExitHTTPError}
   411  		}
   412  	} else {
   413  
   414  		s := server
   415  		mutex.Unlock()
   416  		runStartupHooks()
   417  		if err := s.Serve(ln); err != nil && err != http.ErrServerClosed {
   418  			ErrLogger.Println(err)
   419  			Stop()
   420  			return &Error{err, ExitHTTPError}
   421  		}
   422  	}
   423  
   424  	return nil
   425  }
   426  
   427  func runStartupHooks() {
   428  	for _, hook := range startupHooks {
   429  		go hook()
   430  	}
   431  }
   432  
   433  func registerShutdownHook(readyChan chan struct{}, hook func(context.Context) error) {
   434  	sigChannel = make(chan os.Signal, 64)
   435  	signal.Notify(sigChannel, syscall.SIGINT, syscall.SIGTERM)
   436  
   437  	go func() {
   438  		readyChan <- struct{}{}
   439  		select {
   440  		case <-hookChannel:
   441  			hookChannel <- struct{}{}
   442  		case <-sigChannel: // Block until SIGINT or SIGTERM received
   443  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   444  			defer cancel()
   445  
   446  			mutex.Lock()
   447  			sigChannel = nil
   448  			hook(ctx)
   449  			mutex.Unlock()
   450  		}
   451  	}()
   452  }
   453  
// TODO: refactor server startup (use context)