github.com/System-Glitch/goyave/v3@v3.6.1-0.20210226143142-ac2fe42ee80e/goyave.go (about) 1 package goyave 2 3 import ( 4 "context" 5 "log" 6 "net" 7 "net/http" 8 "os" 9 "os/signal" 10 "strconv" 11 "sync" 12 "syscall" 13 "time" 14 15 "github.com/System-Glitch/goyave/v3/config" 16 "github.com/System-Glitch/goyave/v3/database" 17 "github.com/System-Glitch/goyave/v3/lang" 18 ) 19 20 var ( 21 server *http.Server 22 redirectServer *http.Server 23 router *Router 24 maintenanceHandler http.Handler 25 sigChannel chan os.Signal 26 tlsStopChannel chan struct{} = make(chan struct{}, 1) 27 stopChannel chan struct{} = make(chan struct{}, 1) 28 hookChannel chan struct{} = make(chan struct{}, 1) 29 30 // Critical config entries (cached for better performance) 31 protocol string 32 maxPayloadSize int64 33 defaultLanguage string 34 35 startupHooks []func() 36 shutdownHooks []func() 37 ready bool = false 38 maintenanceEnabled bool = false 39 mutex = &sync.RWMutex{} 40 once sync.Once 41 42 // Logger the logger for default output 43 // Writes to stdout by default. 44 Logger *log.Logger = log.New(os.Stdout, "", log.LstdFlags) 45 46 // AccessLogger the logger for access. This logger 47 // is used by the logging middleware. 48 // Writes to stdout by default. 49 AccessLogger *log.Logger = log.New(os.Stdout, "", 0) 50 51 // ErrLogger the logger in which errors and stacktraces are written. 52 // Writes to stderr by default. 53 ErrLogger *log.Logger = log.New(os.Stderr, "", log.LstdFlags) 54 ) 55 56 const ( 57 // ExitInvalidConfig the exit code returned when the config 58 // validation doesn't pass. 59 ExitInvalidConfig = 3 60 61 // ExitNetworkError the exit code returned when an error 62 // occurs when opening the network listener 63 ExitNetworkError = 4 64 65 // ExitHTTPError the exit code returned when an error 66 // occurs in the HTTP server (port already in use for example) 67 ExitHTTPError = 5 68 ) 69 70 // Error wrapper for errors directely related to the server itself. 71 // Contains an exit code and the original error. 72 type Error struct { 73 ExitCode int 74 Err error 75 } 76 77 func (e *Error) Error() string { 78 return e.Err.Error() 79 } 80 81 // IsReady returns true if the server has finished initializing and 82 // is ready to serve incoming requests. 83 func IsReady() bool { 84 mutex.RLock() 85 defer mutex.RUnlock() 86 return ready 87 } 88 89 // RegisterStartupHook to execute some code once the server is ready and running. 90 func RegisterStartupHook(hook func()) { 91 mutex.Lock() 92 startupHooks = append(startupHooks, hook) 93 mutex.Unlock() 94 } 95 96 // ClearStartupHooks removes all startup hooks. 97 func ClearStartupHooks() { 98 mutex.Lock() 99 startupHooks = []func(){} 100 mutex.Unlock() 101 } 102 103 // RegisterShutdownHook to execute some code after the server stopped. 104 // Shutdown hooks are executed before goyave.Start() returns. 105 func RegisterShutdownHook(hook func()) { 106 mutex.Lock() 107 shutdownHooks = append(shutdownHooks, hook) 108 mutex.Unlock() 109 } 110 111 // ClearShutdownHooks removes all shutdown hooks. 112 func ClearShutdownHooks() { 113 mutex.Lock() 114 shutdownHooks = []func(){} 115 mutex.Unlock() 116 } 117 118 // Start starts the web server. 119 // The routeRegistrer parameter is a function aimed at registering all your routes and middleware. 120 // import ( 121 // "github.com/System-Glitch/goyave/v3" 122 // "github.com/username/projectname/route" 123 // ) 124 // 125 // func main() { 126 // if err := goyave.Start(route.Register); err != nil { 127 // os.Exit(err.(*goyave.Error).ExitCode) 128 // } 129 // } 130 // 131 // Errors returned can be safely type-asserted to "*goyave.Error". 132 // Panics if the server is already running. 133 func Start(routeRegistrer func(*Router)) error { 134 if IsReady() { 135 ErrLogger.Panicf("Server is already running.") 136 } 137 138 mutex.Lock() 139 if !config.IsLoaded() { 140 if err := config.Load(); err != nil { 141 ErrLogger.Println(err) 142 mutex.Unlock() 143 return &Error{ExitInvalidConfig, err} 144 } 145 } 146 147 // Performance improvements by loading critical config entries beforehand 148 cacheCriticalConfig() 149 150 lang.LoadDefault() 151 lang.LoadAllAvailableLanguages() 152 153 if config.GetBool("database.autoMigrate") && config.GetString("database.connection") != "none" { 154 database.Migrate() 155 } 156 157 router = newRouter() 158 routeRegistrer(router) 159 regexCache = nil // Clear regex cache 160 return startServer(router) 161 } 162 163 func cacheCriticalConfig() { 164 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 165 defaultLanguage = config.GetString("app.defaultLanguage") 166 protocol = config.GetString("server.protocol") 167 } 168 169 // EnableMaintenance replace the main server handler with the "Service Unavailable" handler. 170 func EnableMaintenance() { 171 mutex.Lock() 172 server.Handler = getMaintenanceHandler() 173 maintenanceEnabled = true 174 mutex.Unlock() 175 } 176 177 // DisableMaintenance replace the main server handler with the original router. 178 func DisableMaintenance() { 179 mutex.Lock() 180 server.Handler = router 181 maintenanceEnabled = false 182 mutex.Unlock() 183 } 184 185 // IsMaintenanceEnabled return true if the server is currently in maintenance mode. 186 func IsMaintenanceEnabled() bool { 187 mutex.RLock() 188 defer mutex.RUnlock() 189 return maintenanceEnabled 190 } 191 192 // GetRoute get a named route. 193 // Returns nil if the route doesn't exist. 194 func GetRoute(name string) *Route { 195 mutex.Lock() 196 defer mutex.Unlock() 197 return router.namedRoutes[name] 198 } 199 200 func getMaintenanceHandler() http.Handler { 201 once.Do(func() { 202 maintenanceHandler = http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) { 203 resp.WriteHeader(http.StatusServiceUnavailable) 204 }) 205 }) 206 return maintenanceHandler 207 } 208 209 // Stop gracefully shuts down the server without interrupting any 210 // active connections. 211 // 212 // Make sure the program doesn't exit and waits instead for Stop to return. 213 // 214 // Stop does not attempt to close nor wait for hijacked 215 // connections such as WebSockets. The caller of Stop should 216 // separately notify such long-lived connections of shutdown and wait 217 // for them to close, if desired. 218 func Stop() { 219 mutex.Lock() 220 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 221 defer cancel() 222 stop(ctx) 223 if sigChannel != nil { 224 hookChannel <- struct{}{} // Clear shutdown hook 225 <-hookChannel 226 sigChannel = nil 227 } 228 mutex.Unlock() 229 } 230 231 func stop(ctx context.Context) error { 232 var err error 233 if server != nil { 234 err = server.Shutdown(ctx) 235 database.Close() 236 server = nil 237 router = nil 238 ready = false 239 maintenanceEnabled = false 240 if redirectServer != nil { 241 redirectServer.Shutdown(ctx) 242 <-tlsStopChannel 243 redirectServer = nil 244 } 245 246 for _, hook := range shutdownHooks { 247 hook() 248 } 249 stopChannel <- struct{}{} 250 } 251 return err 252 } 253 254 func getHost(protocol string) string { 255 var port string 256 if protocol == "https" { 257 port = "server.httpsPort" 258 } else { 259 port = "server.port" 260 } 261 return config.GetString("server.host") + ":" + strconv.Itoa(config.GetInt(port)) 262 } 263 264 func getAddress(protocol string) string { 265 var shouldShowPort bool 266 var port string 267 if protocol == "https" { 268 p := config.GetInt("server.httpsPort") 269 port = strconv.Itoa(p) 270 shouldShowPort = p != 443 271 } else { 272 p := config.GetInt("server.port") 273 port = strconv.Itoa(p) 274 shouldShowPort = p != 80 275 } 276 host := config.GetString("server.domain") 277 if len(host) == 0 { 278 host = config.GetString("server.host") 279 } 280 281 if shouldShowPort { 282 host += ":" + port 283 } 284 285 return protocol + "://" + host 286 } 287 288 // BaseURL returns the base URL of your application. 289 func BaseURL() string { 290 return getAddress(config.GetString("server.protocol")) 291 } 292 293 func startTLSRedirectServer() { 294 httpsAddress := getAddress("https") 295 timeout := time.Duration(config.GetInt("server.timeout")) * time.Second 296 redirectServer = &http.Server{ 297 Addr: getHost("http"), 298 WriteTimeout: timeout, 299 ReadTimeout: timeout, 300 IdleTimeout: timeout * 2, 301 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 302 address := httpsAddress + r.URL.Path 303 query := r.URL.Query() 304 if len(query) != 0 { 305 address += "?" + query.Encode() 306 } 307 http.Redirect(w, r, address, http.StatusPermanentRedirect) 308 }), 309 } 310 311 ln, err := net.Listen("tcp", redirectServer.Addr) 312 if err != nil { 313 ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error()) 314 redirectServer = nil 315 return 316 } 317 318 ok := ready 319 r := redirectServer 320 321 go func() { 322 if ok && r != nil { 323 if err := r.Serve(ln); err != nil && err != http.ErrServerClosed { 324 ErrLogger.Printf("The TLS redirect server encountered an error: %s\n", err.Error()) 325 mutex.Lock() 326 redirectServer = nil 327 ln.Close() 328 mutex.Unlock() 329 return 330 } 331 } 332 ln.Close() 333 tlsStopChannel <- struct{}{} 334 }() 335 } 336 337 func startServer(router *Router) error { 338 defer func() { 339 <-stopChannel // Wait for stop() to finish before returning 340 }() 341 timeout := time.Duration(config.GetInt("server.timeout")) * time.Second 342 server = &http.Server{ 343 Addr: getHost(protocol), 344 WriteTimeout: timeout, 345 ReadTimeout: timeout, 346 IdleTimeout: timeout * 2, 347 Handler: router, 348 } 349 350 if config.GetBool("server.maintenance") { 351 server.Handler = getMaintenanceHandler() 352 maintenanceEnabled = true 353 } 354 355 ln, err := net.Listen("tcp", server.Addr) 356 if err != nil { 357 ErrLogger.Println(err) 358 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 359 defer cancel() 360 stop(ctx) 361 mutex.Unlock() 362 return &Error{ExitNetworkError, err} 363 } 364 defer ln.Close() 365 registerShutdownHook(stop) 366 <-hookChannel 367 368 ready = true 369 if protocol == "https" { 370 startTLSRedirectServer() 371 372 s := server 373 mutex.Unlock() 374 runStartupHooks() 375 if err := s.ServeTLS(ln, config.GetString("server.tls.cert"), config.GetString("server.tls.key")); err != nil && err != http.ErrServerClosed { 376 ErrLogger.Println(err) 377 Stop() 378 return &Error{ExitHTTPError, err} 379 } 380 } else { 381 382 s := server 383 mutex.Unlock() 384 runStartupHooks() 385 if err := s.Serve(ln); err != nil && err != http.ErrServerClosed { 386 ErrLogger.Println(err) 387 Stop() 388 return &Error{ExitHTTPError, err} 389 } 390 } 391 392 return nil 393 } 394 395 func runStartupHooks() { 396 for _, hook := range startupHooks { 397 go hook() 398 } 399 } 400 401 func registerShutdownHook(hook func(context.Context) error) { 402 sigChannel = make(chan os.Signal, 1) 403 signal.Notify(sigChannel, syscall.SIGINT, syscall.SIGTERM) 404 405 go func() { 406 hookChannel <- struct{}{} 407 select { 408 case <-hookChannel: 409 hookChannel <- struct{}{} 410 case <-sigChannel: // Block until SIGINT or SIGTERM received 411 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 412 defer cancel() 413 414 mutex.Lock() 415 sigChannel = nil 416 hook(ctx) 417 mutex.Unlock() 418 } 419 }() 420 } 421 422 // TODO refactor server sartup (use context)