// Code generated by go-swagger; DO NOT EDIT.


{{ if .Copyright -}}// {{ comment .Copyright -}}{{ end }}


package {{ .APIPackage }}

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/go-openapi/swag"
	{{ if .UseGoStructFlags }}flags "github.com/jessevdk/go-flags"
	{{ end -}}
	"github.com/go-openapi/runtime/flagext"
	{{ if .UsePFlags }}flag "github.com/spf13/pflag"
	{{ end -}}
	{{ if .UseFlags }}"flag"
	"strings"
	{{ end -}}
	"golang.org/x/net/netutil"

	{{ imports .DefaultImports }}
	{{ imports .Imports }}
)

const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
	schemeUnix  = "unix"
)

// defaultSchemes are the listeners enabled when none are chosen explicitly;
// they are derived from the schemes declared in the swagger spec.
var defaultSchemes []string

func init() {
	defaultSchemes = []string{ {{ if (hasInsecure .Schemes) }}
		schemeHTTP,{{ end}}{{ if (hasSecure .Schemes) }}
		schemeHTTPS,{{ end }}{{ if (contains .ExtraSchemes "unix") }}
		schemeUnix,{{ end }}
	}
}

{{ if not .UseGoStructFlags}}
var ({{ if .ExcludeSpec }}
	specFile string
	{{ end }}enabledListeners []string
	cleanupTimeout   time.Duration
	gracefulTimeout  time.Duration
	maxHeaderSize    flagext.ByteSize

	socketPath string

	host         string
	port         int
	listenLimit  int
	keepAlive    time.Duration
	readTimeout  time.Duration
	writeTimeout time.Duration

	tlsHost           string
	tlsPort           int
	tlsListenLimit    int
	tlsKeepAlive      time.Duration
	tlsReadTimeout    time.Duration
	tlsWriteTimeout   time.Duration
	tlsCertificate    string
	tlsCertificateKey string
	tlsCACertificate  string
)

{{ if .UseFlags}}
// StringSliceVar support for flag

// sliceValue is a flag.Value collecting comma-separated strings into a
// []string, mimicking pflag's StringSliceVar for the standard flag package.
// The first Set replaces the default values; later Sets append to them, so
// the flag can be repeated on the command line.
type sliceValue struct {
	value   *[]string
	changed bool
}

// newSliceValue stores vals in *p as the default and returns a flag.Value bound to p.
func newSliceValue(vals []string, p *[]string) *sliceValue {
	*p = vals
	return &sliceValue{value: p}
}

// Set parses a comma-separated list. The first call discards the defaults;
// subsequent calls append to the values collected so far.
func (s *sliceValue) Set(val string) error {
	parts := strings.Split(val, ",")
	if !s.changed {
		*s.value = parts
		s.changed = true
		return nil
	}
	*s.value = append(*s.value, parts...)
	return nil
}

// Get implements flag.Getter.
func (s *sliceValue) Get() interface{} { return *s.value }

// String returns the current values joined by commas. The flag package may
// call String on a zero value (to detect defaults), hence the nil checks.
func (s *sliceValue) String() string {
	if s == nil || s.value == nil {
		return ""
	}
	return strings.Join(*s.value, ",")
}
// end StringSliceVar support for flag
{{ end }}

func init() {
	maxHeaderSize = flagext.ByteSize(1000000){{ if .ExcludeSpec }}
	// StringVar (not the pflag-only StringVarP) so this also compiles with the standard flag package.
	flag.StringVar(&specFile, "spec", "", "the swagger specification to serve")
	{{ end }}
	{{ if .UseFlags }}
	listeners := newSliceValue(defaultSchemes, &enabledListeners)
	flag.Var(listeners, "scheme", "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
	// "schema" was the historical (misspelled) name of this flag; keep it as an alias.
	flag.Var(listeners, "schema", "deprecated alias for --scheme")
	{{ end }}
	{{ if .UsePFlags }}
	flag.StringSliceVar(&enabledListeners, "scheme", defaultSchemes, "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
	{{ end }}
	flag.DurationVar(&cleanupTimeout, "cleanup-timeout", 10*time.Second, "grace period for which to wait before killing idle connections")
	flag.DurationVar(&gracefulTimeout, "graceful-timeout", 15*time.Second, "grace period for which to wait before shutting down the server")
	flag.Var(&maxHeaderSize, "max-header-size", "controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body")

	flag.StringVar(&socketPath, "socket-path", "/var/run/{{ dasherize .Name }}.sock", "the unix socket to listen on")

	flag.StringVar(&host, "host", "localhost", "the IP to listen on")
	flag.IntVar(&port, "port", 0, "the port to listen on for insecure connections, defaults to a random value")
	flag.IntVar(&listenLimit, "listen-limit", 0, "limit the number of outstanding requests")
	flag.DurationVar(&keepAlive, "keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
	flag.DurationVar(&readTimeout, "read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
	flag.DurationVar(&writeTimeout, "write-timeout", 30*time.Second, "maximum duration before timing out write of the response")

	// An empty default lets NewServer/Listen fall back to the --host value.
	flag.StringVar(&tlsHost, "tls-host", "", "the IP to listen on for tls, when not specified it's the same as --host")
	flag.IntVar(&tlsPort, "tls-port", 0, "the port to listen on for secure connections, defaults to a random value")
	flag.StringVar(&tlsCertificate, "tls-certificate", "", "the certificate file to use for secure connections")
	flag.StringVar(&tlsCertificateKey, "tls-key", "", "the private key file to use for secure connections (without passphrase)")
	flag.StringVar(&tlsCACertificate, "tls-ca", "", "the certificate authority certificate file to be used with mutual tls auth")
	flag.IntVar(&tlsListenLimit, "tls-listen-limit", 0, "limit the number of outstanding requests")
	flag.DurationVar(&tlsKeepAlive, "tls-keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
	flag.DurationVar(&tlsReadTimeout, "tls-read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
	flag.DurationVar(&tlsWriteTimeout, "tls-write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
}

// stringEnvOverride returns the value of the first environment variable in
// keys that is set to a non-empty string. Otherwise it returns def when orig
// is empty and def is not, and orig in every other case.
func stringEnvOverride(orig string, def string, keys ...string) string {
	for _, k := range keys {
		if v := os.Getenv(k); v != "" {
			return v
		}
	}
	if def != "" && orig == "" {
		return def
	}
	return orig
}

// intEnvOverride is the integer counterpart of stringEnvOverride. It exits the
// process with status 1 when the selected environment variable is not a valid integer.
func intEnvOverride(orig int, def int, keys ...string) int {
	for _, k := range keys {
		if os.Getenv(k) != "" {
			v, err := strconv.Atoi(os.Getenv(k))
			if err != nil {
				fmt.Fprintln(os.Stderr, k, "is not a valid number")
				os.Exit(1)
			}
			return v
		}
	}
	if def != 0 && orig == 0 {
		return def
	}
	return orig
}
{{ end }}

// NewServer creates a new api {{ humanize .Name }} server but does not configure it
func NewServer(api *{{ .APIPackageAlias }}.{{ pascalize .Name }}API) *Server {
	s := new(Server)
	{{ if not .UseGoStructFlags }}
	s.EnabledListeners = enabledListeners
	s.CleanupTimeout = cleanupTimeout
	s.GracefulTimeout = gracefulTimeout
	s.MaxHeaderSize = maxHeaderSize
	s.SocketPath = socketPath
	s.Host = stringEnvOverride(host, "", "HOST")
	s.Port = intEnvOverride(port, 0, "PORT")
	s.ListenLimit = listenLimit
	s.KeepAlive = keepAlive
	s.ReadTimeout = readTimeout
	s.WriteTimeout = writeTimeout
	s.TLSHost = stringEnvOverride(tlsHost, s.Host, "TLS_HOST", "HOST")
	s.TLSPort = intEnvOverride(tlsPort, 0, "TLS_PORT")
	s.TLSCertificate = stringEnvOverride(tlsCertificate, "", "TLS_CERTIFICATE")
	s.TLSCertificateKey = stringEnvOverride(tlsCertificateKey, "", "TLS_PRIVATE_KEY")
	s.TLSCACertificate = stringEnvOverride(tlsCACertificate, "", "TLS_CA_CERTIFICATE")
	s.TLSListenLimit = tlsListenLimit
	s.TLSKeepAlive = tlsKeepAlive
	s.TLSReadTimeout = tlsReadTimeout
	s.TLSWriteTimeout = tlsWriteTimeout
	{{- if .ExcludeSpec }}
	s.Spec = specFile
	{{- end }}
	{{- end }}
	s.shutdown = make(chan struct{})
	s.api = api
	s.interrupt = make(chan os.Signal, 1)
	return s
}

// ConfigureAPI configures the API and handlers.
func (s *Server) ConfigureAPI() {
	if s.api != nil {
		s.handler = configureAPI(s.api)
	}
}

// ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
func (s *Server) ConfigureFlags() {
	if s.api != nil {
		configureFlags(s.api)
	}
}

// Server for the {{ humanize .Name }} API
type Server struct {
	EnabledListeners []string{{ if .UseGoStructFlags }} `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`{{ end }}
	CleanupTimeout   time.Duration{{ if .UseGoStructFlags }} `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`{{ end }}
	GracefulTimeout  time.Duration{{ if .UseGoStructFlags }} `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`{{ end }}
	MaxHeaderSize    flagext.ByteSize{{ if .UseGoStructFlags }} `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`{{ end }}

	SocketPath    {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/{{ dasherize .Name }}.sock"`{{ end }}
	domainSocketL net.Listener

	Host         string{{ if .UseGoStructFlags }} `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`{{ end }}
	Port         int{{ if .UseGoStructFlags }} `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`{{ end }}
	ListenLimit  int{{ if .UseGoStructFlags }} `long:"listen-limit" description:"limit the number of outstanding requests"`{{ end }}
	KeepAlive    time.Duration{{ if .UseGoStructFlags }} `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`{{ end }}
	ReadTimeout  time.Duration{{ if .UseGoStructFlags }} `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`{{ end }}
	WriteTimeout time.Duration{{ if .UseGoStructFlags }} `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"30s"`{{ end }}
	httpServerL  net.Listener

	TLSHost           string{{ if .UseGoStructFlags }} `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`{{ end }}
	TLSPort           int{{ if .UseGoStructFlags }} `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`{{ end }}
	TLSCertificate    {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`{{ end }}
	TLSCertificateKey {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`{{ end }}
	TLSCACertificate  {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`{{ end }}
	TLSListenLimit    int{{ if .UseGoStructFlags }} `long:"tls-listen-limit" description:"limit the number of outstanding requests"`{{ end }}
	TLSKeepAlive      time.Duration{{ if .UseGoStructFlags }} `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`{{ end }}
	TLSReadTimeout    time.Duration{{ if .UseGoStructFlags }} `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`{{ end }}
	TLSWriteTimeout   time.Duration{{ if .UseGoStructFlags }} `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`{{ end }}
	httpsServerL      net.Listener

	{{ if .ExcludeSpec }}Spec {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"spec" description:"the swagger specification to serve"`{{ end }}{{ end }}
	api          *{{ .APIPackageAlias }}.{{ pascalize .Name }}API
	handler      http.Handler
	hasListeners bool
	shutdown     chan struct{}
	shuttingDown int32
	interrupted  bool
	interrupt    chan os.Signal
}

// Logf logs message either via defined user logger or via system one if no user logger is defined.
func (s *Server) Logf(f string, args ...interface{}) {
	if s.api != nil && s.api.Logger != nil {
		s.api.Logger(f, args...)
	} else {
		log.Printf(f, args...)
	}
}

// Fatalf logs message either via defined user logger or via system one if no user logger is defined.
// Exits with non-zero status after printing
func (s *Server) Fatalf(f string, args ...interface{}) {
	if s.api != nil && s.api.Logger != nil {
		s.api.Logger(f, args...)
		os.Exit(1)
	} else {
		log.Fatalf(f, args...)
	}
}

// SetAPI configures the server with the specified API. Needs to be called before Serve
func (s *Server) SetAPI(api *{{ .APIPackageAlias }}.{{ pascalize .Name }}API) {
	if api == nil {
		s.api = nil
		s.handler = nil
		return
	}

	s.api = api
	s.handler = configureAPI(api)
}

// hasScheme reports whether scheme is among the enabled listeners, falling
// back to defaultSchemes when none were configured.
func (s *Server) hasScheme(scheme string) bool {
	schemes := s.EnabledListeners
	if len(schemes) == 0 {
		schemes = defaultSchemes
	}

	for _, v := range schemes {
		if v == scheme {
			return true
		}
	}
	return false
}

// Serve the api.
//
// It calls Listen if the listeners were not created yet, falls back to the
// API's default handler when no handler was set, starts one http.Server per
// enabled scheme and blocks until every server has stopped and the shutdown
// handler (triggered by Shutdown or SIGINT/SIGTERM) has completed.
func (s *Server) Serve() (err error) {
	if !s.hasListeners {
		if err = s.Listen(); err != nil {
			return err
		}
	}

	// set default handler, if none is set
	if s.handler == nil {
		if s.api == nil {
			return errors.New("can't create the default handler, as no api is set")
		}

		s.SetHandler(s.api.Serve(nil))
	}

	wg := new(sync.WaitGroup)
	once := new(sync.Once)
	signalNotify(s.interrupt)
	go handleInterrupt(once, s)

	servers := []*http.Server{}

	if s.hasScheme(schemeUnix) {
		domainSocket := new(http.Server)
		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
		domainSocket.Handler = s.handler
		if int64(s.CleanupTimeout) > 0 {
			domainSocket.IdleTimeout = s.CleanupTimeout
		}

		configureServer(domainSocket, "unix", string(s.SocketPath))

		servers = append(servers, domainSocket)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at unix://%s", s.SocketPath)
		go func(l net.Listener) {
			defer wg.Done()
			if err := domainSocket.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at unix://%s", s.SocketPath)
		}(s.domainSocketL)
	}

	if s.hasScheme(schemeHTTP) {
		httpServer := new(http.Server)
		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpServer.ReadTimeout = s.ReadTimeout
		httpServer.WriteTimeout = s.WriteTimeout
		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
		if s.ListenLimit > 0 {
			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
		}

		if int64(s.CleanupTimeout) > 0 {
			httpServer.IdleTimeout = s.CleanupTimeout
		}

		httpServer.Handler = s.handler

		configureServer(httpServer, "http", s.httpServerL.Addr().String())

		servers = append(servers, httpServer)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at http://%s", s.httpServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpServer.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at http://%s", l.Addr())
		}(s.httpServerL)
	}

	if s.hasScheme(schemeHTTPS) {
		httpsServer := new(http.Server)
		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpsServer.ReadTimeout = s.TLSReadTimeout
		httpsServer.WriteTimeout = s.TLSWriteTimeout
		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
		if s.TLSListenLimit > 0 {
			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
		}
		if int64(s.CleanupTimeout) > 0 {
			httpsServer.IdleTimeout = s.CleanupTimeout
		}
		httpsServer.Handler = s.handler

		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
		httpsServer.TLSConfig = &tls.Config{
			// Causes servers to use Go's default ciphersuite preferences,
			// which are tuned to avoid attacks. Does nothing on clients.
			PreferServerCipherSuites: true,
			// Only use curves which have assembly implementations
			// https://github.com/golang/go/tree/master/src/crypto/elliptic
			CurvePreferences: []tls.CurveID{tls.CurveP256},
		{{- if .UseModernMode }}
			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
			NextProtos: []string{"h2", "http/1.1"},
			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
			MinVersion: tls.VersionTLS12,
			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
			CipherSuites: []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			},
		{{- end }}
		}

		// build standard config from server options
		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair({{ if .UseGoStructFlags }}string({{ end }}s.TLSCertificate{{ if .UseGoStructFlags }}){{ end }}, {{ if .UseGoStructFlags }}string({{ end }}s.TLSCertificateKey{{ if .UseGoStructFlags }}){{ end }})
			if err != nil {
				return err
			}
		}

		if s.TLSCACertificate != "" {
			// include specified CA certificate
			caCert, caCertErr := os.ReadFile({{ if .UseGoStructFlags }}string({{ end }}s.TLSCACertificate{{ if .UseGoStructFlags }}){{ end }})
			if caCertErr != nil {
				return caCertErr
			}
			caCertPool := x509.NewCertPool()
			ok := caCertPool.AppendCertsFromPEM(caCert)
			if !ok {
				return fmt.Errorf("cannot parse CA certificate")
			}
			httpsServer.TLSConfig.ClientCAs = caCertPool
			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
		}

		// call custom TLS configurator
		configureTLS(httpsServer.TLSConfig)

		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
			// after standard and custom config are passed, this ends up with no certificate
			if s.TLSCertificate == "" {
				if s.TLSCertificateKey == "" {
					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
				}
				s.Fatalf("the required flag `--tls-certificate` was not specified")
			}
			if s.TLSCertificateKey == "" {
				s.Fatalf("the required flag `--tls-key` was not specified")
			}
			// this happens with a wrong custom TLS configurator
			s.Fatalf("no certificate was configured for TLS")
		}

		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())

		servers = append(servers, httpsServer)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at https://%s", s.httpsServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpsServer.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at https://%s", l.Addr())
		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
	}

	wg.Add(1)
	go s.handleShutdown(wg, &servers)

	wg.Wait()
	return nil
}

// Listen creates the listeners for the server.
//
// If creating one listener fails, the listeners already opened by this call
// are closed again, so a failed Listen neither leaks sockets nor leaves a
// unix socket bound, and it can be retried.
func (s *Server) Listen() (err error) {
	if s.hasListeners { // already done this
		return nil
	}

	// err is the named result, so it holds whatever error a `return` below produced.
	defer func() {
		if err == nil {
			return
		}
		for _, l := range []net.Listener{s.domainSocketL, s.httpServerL, s.httpsServerL} {
			if l != nil {
				_ = l.Close() // best effort: the original error is the one reported
			}
		}
		s.domainSocketL, s.httpServerL, s.httpsServerL = nil, nil, nil
	}()

	if s.hasScheme(schemeHTTPS) {
		// Use http host if https host wasn't defined
		if s.TLSHost == "" {
			s.TLSHost = s.Host
		}
		// Use http listen limit if https listen limit wasn't defined
		if s.TLSListenLimit == 0 {
			s.TLSListenLimit = s.ListenLimit
		}
		// Use http tcp keep alive if https tcp keep alive wasn't defined
		if int64(s.TLSKeepAlive) == 0 {
			s.TLSKeepAlive = s.KeepAlive
		}
		// Use http read timeout if https read timeout wasn't defined
		if int64(s.TLSReadTimeout) == 0 {
			s.TLSReadTimeout = s.ReadTimeout
		}
		// Use http write timeout if https write timeout wasn't defined
		if int64(s.TLSWriteTimeout) == 0 {
			s.TLSWriteTimeout = s.WriteTimeout
		}
	}

	if s.hasScheme(schemeUnix) {
		domSockListener, err := net.Listen("unix", string(s.SocketPath))
		if err != nil {
			return err
		}
		s.domainSocketL = domSockListener
	}

	if s.hasScheme(schemeHTTP) {
		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
		if err != nil {
			return err
		}
		// record the listener right away so the cleanup above can close it
		s.httpServerL = listener

		h, p, err := swag.SplitHostPort(listener.Addr().String())
		if err != nil {
			return err
		}
		s.Host = h
		s.Port = p
	}

	if s.hasScheme(schemeHTTPS) {
		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
		if err != nil {
			return err
		}
		// record the listener right away so the cleanup above can close it
		s.httpsServerL = tlsListener

		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
		if err != nil {
			return err
		}
		s.TLSHost = sh
		s.TLSPort = sp
	}

	s.hasListeners = true
	return nil
}

// Shutdown server and clean up resources
func (s *Server) Shutdown() error {
	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
		close(s.shutdown)
	}
	return nil
}

// handleShutdown waits for Shutdown to be requested, runs the API's
// pre-shutdown hook, gracefully shuts down every server within
// GracefulTimeout and, only if all of them stopped cleanly, runs the API's
// ServerShutdown hook.
func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
	// wg.Done must occur last, after s.api.ServerShutdown()
	// (to preserve old behaviour)
	defer wg.Done()

	<-s.shutdown

	servers := *serversPtr

	ctx, cancel := context.WithTimeout(context.Background(), s.GracefulTimeout)
	defer cancel()

	// first execute the pre-shutdown hook; the API may be nil when the
	// server was driven only through SetHandler
	if s.api != nil {
		s.api.PreServerShutdown()
	}

	shutdownChan := make(chan bool, len(servers))
	for i := range servers {
		server := servers[i]
		go func() {
			var success bool
			defer func() {
				shutdownChan <- success
			}()
			if err := server.Shutdown(ctx); err != nil {
				// Error from closing listeners, or context timeout:
				s.Logf("HTTP server Shutdown: %v", err)
			} else {
				success = true
			}
		}()
	}

	// Wait until all listeners have shut down before calling ServerShutdown.
	// Every result is received (no short-circuit), so no sender is left blocked.
	success := true
	for range servers {
		if ok := <-shutdownChan; !ok {
			success = false
		}
	}
	if success && s.api != nil {
		s.api.ServerShutdown()
	}
}

// GetHandler returns a handler useful for testing
func (s *Server) GetHandler() http.Handler {
	return s.handler
}

// SetHandler allows for setting a http handler on this server
func (s *Server) SetHandler(handler http.Handler) {
	s.handler = handler
}

// UnixListener returns the domain socket listener
func (s *Server) UnixListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.domainSocketL, nil
}

// HTTPListener returns the http listener
func (s *Server) HTTPListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.httpServerL, nil
}

// TLSListener returns the https listener
func (s *Server) TLSListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.httpsServerL, nil
}

// handleInterrupt triggers Shutdown on the first signal received on
// s.interrupt and only logs any later ones.
func handleInterrupt(once *sync.Once, s *Server) {
	once.Do(func() {
		for range s.interrupt {
			if s.interrupted {
				s.Logf("Server already shutting down")
				continue
			}
			s.interrupted = true
			s.Logf("Shutting down... ")
			if err := s.Shutdown(); err != nil {
				s.Logf("HTTP server Shutdown: %v", err)
			}
		}
	})
}

// signalNotify relays SIGINT and SIGTERM to interrupt.
func signalNotify(interrupt chan<- os.Signal) {
	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
}