goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/server.go (about)

     1  package goyave
     2  
import (
	"context"
	"database/sql"
	stderrors "errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"sync/atomic"
	"syscall"
	"time"

	"gorm.io/gorm"
	"goyave.dev/goyave/v5/config"
	"goyave.dev/goyave/v5/database"
	"goyave.dev/goyave/v5/lang"
	"goyave.dev/goyave/v5/slog"
	"goyave.dev/goyave/v5/util/errors"
	"goyave.dev/goyave/v5/util/fsutil"
	"goyave.dev/goyave/v5/util/fsutil/osfs"
)
    28  
// serverKey is a context key used to store the server instance into its base context.
// Being an unexported empty struct type, it cannot collide with context keys
// declared by other packages. Use `ServerFromContext` to read the value back.
type serverKey struct{}
    31  
// Options represent server creation options.
// Every field is optional; the default used when a field is left
// unset is described on the field itself.
type Options struct {

	// Config used by the server and propagated to all its components.
	// If no configuration is provided, the default configuration is
	// automatically loaded using `config.Load()`.
	Config *config.Config

	// Logger used by the server and propagated to all its components.
	// If no logger is provided in the options, a default logger writing
	// to stderr is used.
	Logger *slog.Logger

	// LangFS is the file system from which the language files
	// will be loaded. This file system is expected to contain
	// a `resources/lang` directory.
	// If not provided, uses `osfs.FS` as a default.
	LangFS fsutil.FS

	// ConnState specifies an optional callback function that is
	// called when a client connection changes state. See the
	// `http.ConnState` type and associated constants for details.
	ConnState func(net.Conn, http.ConnState)

	// BaseContext optionally defines a function that returns the base context
	// for the server. It will be used as base context for all incoming requests.
	//
	// The provided `net.Listener` is the specific Listener that's
	// about to start accepting requests.
	//
	// If not given, the default is `context.Background()`.
	// The function must not return a nil context: `Server.Start()` panics if it does.
	//
	// The server instance is then added to the returned context as a value.
	// The server can thus be retrieved using `goyave.ServerFromContext(ctx)`.
	//
	// If the returned context is already canceled when the server starts,
	// `Server.Start()` returns an error without serving any request.
	// If the context is canceled later, the server won't shut down automatically, you are
	// responsible for calling `server.Stop()` if you want this to happen. Otherwise the
	// server will continue serving requests, at the risk of generating "context canceled" errors.
	BaseContext func(net.Listener) context.Context

	// ConnContext optionally specifies a function that modifies
	// the context used for a new connection `c`. The provided context
	// is derived from the base context and has the server instance value, which can
	// be retrieved using `goyave.ServerFromContext(ctx)`.
	ConnContext func(ctx context.Context, c net.Conn) context.Context

	// MaxHeaderBytes controls the maximum number of bytes the
	// server will read parsing the request header's keys and
	// values, including the request line. It does not limit the
	// size of the request body.
	// If zero, http.DefaultMaxHeaderBytes is used.
	MaxHeaderBytes int
}
    84  
// Server is the central component of a Goyave application.
// Create one with `New()`: the zero value is not usable (for example,
// its services map is not initialized).
type Server struct {
	server *http.Server // underlying standard library HTTP server
	config *config.Config
	Lang   *lang.Languages // languages loaded from `Options.LangFS`

	router *Router
	db     *gorm.DB // nil when "database.connection" is "none"

	services map[string]Service // registered services, indexed by their name

	// Logger is the logger used for default output.
	// Writes to stderr by default.
	Logger *slog.Logger

	host         string // value of the "server.host" config entry, without port
	baseURL      string // cached result of getAddress, see refreshURLs
	proxyBaseURL string // cached result of getProxyAddress, see refreshURLs

	stopChannel chan struct{}  // receives a value once Start() has fully shut down
	sigChannel  chan os.Signal // nil unless RegisterSignalHook() was called

	ctx           context.Context                    // requests base context, replaced by Start()
	baseContext   func(net.Listener) context.Context // Options.BaseContext
	startupHooks  []func(*Server)
	shutdownHooks []func(*Server)

	// port is initially the "server.port" config entry, then replaced by
	// the actual listening port in Start() (relevant when configured as 0).
	port int

	state atomic.Uint32 // 0 -> created, 1 -> preparing, 2 -> ready, 3 -> stopped
}
   116  
   117  // New create a new `Server` using the given options.
   118  func New(opts Options) (*Server, error) {
   119  
   120  	cfg := opts.Config
   121  
   122  	if opts.Config == nil {
   123  		var err error
   124  		cfg, err = config.Load()
   125  		if err != nil {
   126  			return nil, errors.New(err)
   127  		}
   128  	}
   129  
   130  	slogger := opts.Logger
   131  	if slogger == nil {
   132  		slogger = slog.New(slog.NewHandler(cfg.GetBool("app.debug"), os.Stderr))
   133  	}
   134  
   135  	langFS := opts.LangFS
   136  	if langFS == nil {
   137  		langFS = &osfs.FS{}
   138  	}
   139  
   140  	languages := lang.New()
   141  	languages.Default = cfg.GetString("app.defaultLanguage")
   142  	if err := languages.LoadAllAvailableLanguages(langFS); err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	port := cfg.GetInt("server.port")
   147  	host := cfg.GetString("server.host") + ":" + strconv.Itoa(port)
   148  
   149  	server := &Server{
   150  		server: &http.Server{
   151  			Addr:              host,
   152  			WriteTimeout:      time.Duration(cfg.GetInt("server.writeTimeout")) * time.Second,
   153  			ReadTimeout:       time.Duration(cfg.GetInt("server.readTimeout")) * time.Second,
   154  			ReadHeaderTimeout: time.Duration(cfg.GetInt("server.readHeaderTimeout")) * time.Second,
   155  			IdleTimeout:       time.Duration(cfg.GetInt("server.idleTimeout")) * time.Second,
   156  			ConnState:         opts.ConnState,
   157  			ConnContext:       opts.ConnContext,
   158  			MaxHeaderBytes:    opts.MaxHeaderBytes,
   159  		},
   160  		ctx:           context.Background(),
   161  		baseContext:   opts.BaseContext,
   162  		config:        cfg,
   163  		services:      make(map[string]Service),
   164  		Lang:          languages,
   165  		stopChannel:   make(chan struct{}, 1),
   166  		startupHooks:  []func(*Server){},
   167  		shutdownHooks: []func(*Server){},
   168  		host:          cfg.GetString("server.host"),
   169  		port:          port,
   170  		Logger:        slogger,
   171  	}
   172  	server.server.BaseContext = server.internalBaseContext
   173  	server.refreshURLs()
   174  	server.server.ErrorLog = log.New(&errLogWriter{server: server}, "", 0)
   175  
   176  	if cfg.GetString("database.connection") != "none" {
   177  		db, err := database.New(cfg, func() *slog.Logger { return server.Logger })
   178  		if err != nil {
   179  			return nil, errors.New(err)
   180  		}
   181  		server.db = db
   182  	}
   183  
   184  	server.router = NewRouter(server)
   185  	server.server.Handler = server.router
   186  	return server, nil
   187  }
   188  
   189  func (s *Server) internalBaseContext(_ net.Listener) context.Context {
   190  	return s.ctx
   191  }
   192  
   193  func (s *Server) getAddress(cfg *config.Config) string {
   194  	shouldShowPort := s.port != 80
   195  	host := cfg.GetString("server.domain")
   196  	if len(host) == 0 {
   197  		host = cfg.GetString("server.host")
   198  		if host == "0.0.0.0" {
   199  			host = "127.0.0.1"
   200  		}
   201  	}
   202  
   203  	if shouldShowPort {
   204  		host += ":" + strconv.Itoa(s.port)
   205  	}
   206  
   207  	return "http://" + host
   208  }
   209  
   210  func (s *Server) getProxyAddress(cfg *config.Config) string {
   211  	if !cfg.Has("server.proxy.host") {
   212  		return s.getAddress(cfg)
   213  	}
   214  
   215  	var shouldShowPort bool
   216  	proto := cfg.GetString("server.proxy.protocol")
   217  	port := cfg.GetInt("server.proxy.port")
   218  	if proto == "https" {
   219  		shouldShowPort = port != 443
   220  	} else {
   221  		shouldShowPort = port != 80
   222  	}
   223  	host := cfg.GetString("server.proxy.host")
   224  	if shouldShowPort {
   225  		host += ":" + strconv.Itoa(port)
   226  	}
   227  
   228  	return proto + "://" + host + cfg.GetString("server.proxy.base")
   229  }
   230  
   231  func (s *Server) refreshURLs() {
   232  	s.baseURL = s.getAddress(s.config)
   233  	s.proxyBaseURL = s.getProxyAddress(s.config)
   234  }
   235  
   236  // Service returns the service identified by the given name.
   237  // Panics if no service could be found with the given name.
   238  func (s *Server) Service(name string) Service {
   239  	if s, ok := s.services[name]; ok {
   240  		return s
   241  	}
   242  	panic(errors.Errorf("service %q does not exist", name))
   243  }
   244  
   245  // LookupService search for a service by its name. If the service
   246  // identified by the given name exists, it is returned with the `true` boolean.
   247  // Otherwise returns `nil` and `false`.
   248  func (s *Server) LookupService(name string) (Service, bool) {
   249  	service, ok := s.services[name]
   250  	return service, ok
   251  }
   252  
   253  // RegisterService on thise server using its name (returned by `Service.Name()`).
   254  // A service's name should be unique.
   255  // `Service.Init(server)` is called on the given service upon registration.
   256  func (s *Server) RegisterService(service Service) {
   257  	s.services[service.Name()] = service
   258  }
   259  
   260  // Host returns the hostname and port the server is running on.
   261  func (s *Server) Host() string {
   262  	return s.host + ":" + strconv.Itoa(s.port)
   263  }
   264  
   265  // Port returns the port the server is running on.
   266  func (s *Server) Port() int {
   267  	return s.port
   268  }
   269  
   270  // BaseURL returns the base URL of your application.
   271  // If "server.domain" is set in the config, uses it instead
   272  // of an IP address.
   273  func (s *Server) BaseURL() string {
   274  	return s.baseURL
   275  }
   276  
   277  // ProxyBaseURL returns the base URL of your application based on the "server.proxy" configuration.
   278  // This is useful when you want to generate an URL when your application is served behind a reverse proxy.
   279  // If "server.proxy.host" configuration is not set, returns the same value as "BaseURL()".
   280  func (s *Server) ProxyBaseURL() string {
   281  	return s.proxyBaseURL
   282  }
   283  
   284  // IsReady returns true if the server has finished initializing and
   285  // is ready to serve incoming requests.
   286  // This operation is concurrently safe.
   287  func (s *Server) IsReady() bool {
   288  	return s.state.Load() == 2
   289  }
   290  
   291  // RegisterStartupHook to execute some code once the server is ready and running.
   292  // All startup hooks are executed in a single goroutine and in order of registration.
   293  func (s *Server) RegisterStartupHook(hook func(*Server)) {
   294  	s.startupHooks = append(s.startupHooks, hook)
   295  }
   296  
   297  // ClearStartupHooks removes all startup hooks.
   298  func (s *Server) ClearStartupHooks() {
   299  	s.startupHooks = []func(*Server){}
   300  }
   301  
   302  // RegisterShutdownHook to execute some code after the server stopped.
   303  // Shutdown hooks are executed before `Start()` returns and are NOT executed
   304  // in a goroutine, meaning that the shutdown process can be blocked by your
   305  // shutdown hooks. It is your responsibility to implement a timeout mechanism
   306  // inside your hook if necessary.
   307  func (s *Server) RegisterShutdownHook(hook func(*Server)) {
   308  	s.shutdownHooks = append(s.shutdownHooks, hook)
   309  }
   310  
   311  // ClearShutdownHooks removes all shutdown hooks.
   312  func (s *Server) ClearShutdownHooks() {
   313  	s.shutdownHooks = []func(*Server){}
   314  }
   315  
   316  // Config returns the server's config.
   317  func (s *Server) Config() *config.Config {
   318  	return s.config
   319  }
   320  
   321  // DB returns the root database instance. Panics if no
   322  // database connection is set up.
   323  func (s *Server) DB() *gorm.DB {
   324  	if s.db == nil {
   325  		panic(errors.NewSkip("No database connection. Database is set to \"none\" in the config", 3))
   326  	}
   327  	return s.db
   328  }
   329  
   330  // Transaction makes it so all DB requests are run inside a transaction.
   331  //
   332  // Returns the rollback function. When you are done, call this function to
   333  // complete the transaction and roll it back. This will also restore the original
   334  // DB so it can be used again out of the transaction.
   335  //
   336  // This is used for tests. This operation is not concurrently safe.
   337  func (s *Server) Transaction(opts ...*sql.TxOptions) func() {
   338  	if s.db == nil {
   339  		panic(errors.NewSkip("No database connection. Database is set to \"none\" in the config", 3))
   340  	}
   341  	ogDB := s.db
   342  	s.db = s.db.Begin(opts...)
   343  	return func() {
   344  		err := s.db.Rollback().Error
   345  		s.db = ogDB
   346  		if err != nil {
   347  			panic(errors.New(err))
   348  		}
   349  	}
   350  }
   351  
   352  // ReplaceDB manually replace the automatic DB connection.
   353  // If a connection already exists, closes it before discarding it.
   354  // This can be used to create a mock DB in tests. Using this function
   355  // is not recommended outside of tests. Prefer using a custom dialect.
   356  // This operation is not concurrently safe.
   357  func (s *Server) ReplaceDB(dialector gorm.Dialector) error {
   358  	if err := s.CloseDB(); err != nil {
   359  		return err
   360  	}
   361  
   362  	db, err := database.NewFromDialector(s.config, func() *slog.Logger { return s.Logger }, dialector)
   363  	if err != nil {
   364  		return err
   365  	}
   366  
   367  	s.db = db
   368  	return nil
   369  }
   370  
   371  // CloseDB close the database connection if there is one.
   372  // Does nothing and returns `nil` if there is no connection.
   373  func (s *Server) CloseDB() error {
   374  	if s.db == nil {
   375  		return nil
   376  	}
   377  	db, err := s.db.DB()
   378  	if err != nil {
   379  		if stderrors.Is(err, gorm.ErrInvalidDB) {
   380  			return nil
   381  		}
   382  		return errors.New(err)
   383  	}
   384  	return errors.New(db.Close())
   385  }
   386  
   387  // Router returns the root router.
   388  func (s *Server) Router() *Router {
   389  	return s.router
   390  }
   391  
   392  // Start the server. This operation is blocking and returns when the server is closed.
   393  func (s *Server) Start() error {
   394  	swapped := s.state.CompareAndSwap(0, 1)
   395  	if !swapped {
   396  		return errors.New("server was already started")
   397  	}
   398  
   399  	defer func() {
   400  		s.state.Store(3)
   401  		// Notify the shutdown is complete so Stop() can return
   402  		s.stopChannel <- struct{}{}
   403  		close(s.stopChannel)
   404  	}()
   405  
   406  	ln, err := net.Listen("tcp", s.server.Addr)
   407  	if err != nil {
   408  		return errors.New(err)
   409  	}
   410  	baseCtx := context.Background()
   411  	if s.baseContext != nil {
   412  		baseCtx = s.baseContext(ln)
   413  		if baseCtx == nil {
   414  			panic("server options BaseContext returned a nil context")
   415  		}
   416  	}
   417  	s.ctx = context.WithValue(baseCtx, serverKey{}, s)
   418  
   419  	select {
   420  	case <-s.ctx.Done():
   421  		return errors.New("cannot start the server, context is canceled")
   422  	default:
   423  	}
   424  
   425  	s.port = ln.Addr().(*net.TCPAddr).Port
   426  	s.refreshURLs()
   427  	defer func() {
   428  		for _, hook := range s.shutdownHooks {
   429  			hook(s)
   430  		}
   431  		if err := s.CloseDB(); err != nil {
   432  			s.Logger.Error(err)
   433  		}
   434  	}()
   435  
   436  	s.state.Store(2)
   437  
   438  	go func(s *Server) {
   439  		if s.IsReady() {
   440  			// We check if the server is ready to prevent startup hook execution
   441  			// if `Serve` returned an error before the goroutine started
   442  			for _, hook := range s.startupHooks {
   443  				hook(s)
   444  			}
   445  		}
   446  	}(s)
   447  	if err := s.server.Serve(ln); err != nil && !stderrors.Is(err, http.ErrServerClosed) {
   448  		s.state.Store(3)
   449  		return errors.New(err)
   450  	}
   451  	return nil
   452  }
   453  
   454  // RegisterRoutes creates a new Router for this Server and runs the given `routeRegistrer`.
   455  func (s *Server) RegisterRoutes(routeRegistrer func(*Server, *Router)) {
   456  	routeRegistrer(s, s.router)
   457  	s.router.ClearRegexCache()
   458  }
   459  
   460  // Stop gracefully shuts down the server without interrupting any
   461  // active connections.
   462  //
   463  // `Stop()` does not attempt to close nor wait for hijacked
   464  // connections such as WebSockets. The caller of `Stop` should
   465  // separately notify such long-lived connections of shutdown and wait
   466  // for them to close, if desired. This can be done using shutdown hooks.
   467  //
   468  // If registered, the OS signal channel is closed.
   469  //
   470  // Make sure the program doesn't exit before `Stop()` returns.
   471  //
   472  // After being stopped, a `Server` is not meant to be re-used.
   473  //
   474  // This function can be called from any goroutine and is concurrently safe.
   475  // Calling this function several times is safe. Calls after the first one are no-op.
   476  func (s *Server) Stop() {
   477  	state := s.state.Swap(3)
   478  	if state == 0 || state == 3 {
   479  		// Start has not been called or Stop has already been called, do nothing
   480  		return
   481  	}
   482  	if s.sigChannel != nil {
   483  		signal.Stop(s.sigChannel)
   484  		close(s.sigChannel)
   485  	}
   486  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   487  	defer cancel()
   488  	err := s.server.Shutdown(ctx)
   489  	if err != nil {
   490  		s.Logger.Error(errors.NewSkip(err, 3))
   491  	}
   492  
   493  	<-s.stopChannel // Wait for stop channel before returning
   494  }
   495  
   496  // RegisterSignalHook creates a channel listening on SIGINT and SIGTERM. When receiving such
   497  // signal, the server is stopped automatically and the listener on these signals is removed.
   498  func (s *Server) RegisterSignalHook() {
   499  
   500  	// Sometimes users may not want to have a sigChannel setup
   501  	// also we don't want it in tests
   502  	// users will have to manually call this function if they want the shutdown on signal feature
   503  
   504  	s.sigChannel = make(chan os.Signal, 64)
   505  	signal.Notify(s.sigChannel, syscall.SIGINT, syscall.SIGTERM)
   506  
   507  	go func() {
   508  		_, ok := <-s.sigChannel
   509  		if ok {
   510  			s.Stop()
   511  		}
   512  	}()
   513  }
   514  
   515  // errLogWriter is a proxy io.Writer that pipes into the server logger.
   516  // This is used so the error logger (type `*log.Logger`) of the underlying
   517  // std HTTP server write to the same logger as the rest of the application.
   518  type errLogWriter struct {
   519  	server *Server
   520  }
   521  
   522  func (w errLogWriter) Write(p []byte) (n int, err error) {
   523  	w.server.Logger.Error(fmt.Errorf("%s", p))
   524  	return len(p), nil
   525  }
   526  
   527  // ServerFromContext returns the `*goyave.Server` stored in the given context or `nil`.
   528  // This is safe to call using any context retrieved from incoming HTTP requests as this value
   529  // is automatically injected when the server is created.
   530  func ServerFromContext(ctx context.Context) *Server {
   531  	s, ok := ctx.Value(serverKey{}).(*Server)
   532  	if !ok {
   533  		return nil
   534  	}
   535  	return s
   536  }