// Code generated by go-swagger; DO NOT EDIT.


{{ if .Copyright -}}// {{ comment .Copyright -}}{{ end }}


package {{ .APIPackage }}

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/go-openapi/swag"
	{{ if .UseGoStructFlags }}flags "github.com/jessevdk/go-flags"
	{{ end -}}
	"github.com/circl-dev/runtime/flagext"
	{{ if .UsePFlags }}flag "github.com/spf13/pflag"
	{{ end -}}
	{{ if .UseFlags }}"flag"
	"strings"
	{{ end -}}
	"golang.org/x/net/netutil"

	{{ imports .DefaultImports }}
	{{ imports .Imports }}
)

const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
	schemeUnix  = "unix"
)

// defaultSchemes lists the listeners enabled when none were explicitly selected,
// i.e. the schemes declared for this API.
var defaultSchemes []string

func init() {
	defaultSchemes = []string{ {{ if (hasInsecure .Schemes) }}
		schemeHTTP,{{ end}}{{ if (hasSecure .Schemes) }}
		schemeHTTPS,{{ end }}{{ if (contains .ExtraSchemes "unix") }}
		schemeUnix,{{ end }}
	}
}

{{ if not .UseGoStructFlags}}
var ({{ if .ExcludeSpec }}
	specFile string
	{{ end }}enabledListeners []string
	cleanupTimeout   time.Duration
	gracefulTimeout  time.Duration
	maxHeaderSize    flagext.ByteSize

	socketPath string

	host         string
	port         int
	listenLimit  int
	keepAlive    time.Duration
	readTimeout  time.Duration
	writeTimeout time.Duration

	tlsHost           string
	tlsPort           int
	tlsListenLimit    int
	tlsKeepAlive      time.Duration
	tlsReadTimeout    time.Duration
	tlsWriteTimeout   time.Duration
	tlsCertificate    string
	tlsCertificateKey string
	tlsCACertificate  string
)

{{ if .UseFlags}}
// sliceValue implements flag.Value for a repeatable, comma separated list of strings
// (StringSliceVar support for the standard library flag package).
//
// The first Set replaces the default values; every further Set appends to them, so the
// flag can be repeated as well as given a comma separated list.
type sliceValue struct {
	value   *[]string
	changed bool
}

// newSliceValue stores vals as the default content of *p and returns a flag.Value writing into *p.
func newSliceValue(vals []string, p *[]string) *sliceValue {
	*p = vals
	return &sliceValue{value: p}
}

// Set parses one occurrence of the flag as a comma separated list of values.
func (s *sliceValue) Set(val string) error {
	parts := strings.Split(val, ",")
	if !s.changed {
		// first occurrence on the command line: drop the defaults
		*s.value = parts
		s.changed = true
		return nil
	}
	*s.value = append(*s.value, parts...)
	return nil
}

// Get returns the current list of values.
func (s *sliceValue) Get() interface{} {
	if s == nil || s.value == nil {
		return []string(nil)
	}
	return *s.value
}

// String renders the current values as a comma separated list.
func (s *sliceValue) String() string {
	// the flag package may call String on a zero value (e.g. when printing defaults)
	if s == nil || s.value == nil {
		return ""
	}
	return strings.Join(*s.value, ",")
}
// end StringSliceVar support for flag
{{ end }}

func init() {
	maxHeaderSize = flagext.ByteSize(1000000){{ if .ExcludeSpec }}
	flag.StringVar(&specFile, "spec", "", "the swagger specification to serve")
	{{ end }}
	{{ if .UseFlags }}
	listeners := newSliceValue(defaultSchemes, &enabledListeners)
	flag.Var(listeners, "scheme", "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
	// "schema" was the former name of this flag: it is kept as an alias for backward compatibility
	flag.Var(listeners, "schema", "deprecated alias of --scheme")
	{{ end }}
	{{ if .UsePFlags }}
	flag.StringSliceVar(&enabledListeners, "scheme", defaultSchemes, "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
	{{ end }}
	flag.DurationVar(&cleanupTimeout, "cleanup-timeout", 10*time.Second, "grace period for which to wait before killing idle connections")
	flag.DurationVar(&gracefulTimeout, "graceful-timeout", 15*time.Second, "grace period for which to wait before shutting down the server")
	flag.Var(&maxHeaderSize, "max-header-size", "controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body")

	flag.StringVar(&socketPath, "socket-path", "/var/run/{{ dasherize .Name }}.sock", "the unix socket to listen on")

	flag.StringVar(&host, "host", "localhost", "the IP to listen on")
	flag.IntVar(&port, "port", 0, "the port to listen on for insecure connections, defaults to a random value")
	flag.IntVar(&listenLimit, "listen-limit", 0, "limit the number of outstanding requests")
	flag.DurationVar(&keepAlive, "keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
	flag.DurationVar(&readTimeout, "read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
	flag.DurationVar(&writeTimeout, "write-timeout", 30*time.Second, "maximum duration before timing out write of the response")

	// an empty default lets NewServer/Listen fall back to the value of --host
	flag.StringVar(&tlsHost, "tls-host", "", "the IP to listen on for tls, when not specified it's the same as --host")
	flag.IntVar(&tlsPort, "tls-port", 0, "the port to listen on for secure connections, defaults to a random value")
	flag.StringVar(&tlsCertificate, "tls-certificate", "", "the certificate file to use for secure connections")
	flag.StringVar(&tlsCertificateKey, "tls-key", "", "the private key file to use for secure connections (without passphrase)")
	flag.StringVar(&tlsCACertificate, "tls-ca", "", "the certificate authority certificate file to be used with mutual tls auth")
	flag.IntVar(&tlsListenLimit, "tls-listen-limit", 0, "limit the number of outstanding requests")
	flag.DurationVar(&tlsKeepAlive, "tls-keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
	flag.DurationVar(&tlsReadTimeout, "tls-read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
	flag.DurationVar(&tlsWriteTimeout, "tls-write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
}

// stringEnvOverride returns the value of the first environment variable in keys that is set
// to a non-empty value. Otherwise it returns orig, or def when orig is empty.
func stringEnvOverride(orig string, def string, keys ...string) string {
	for _, k := range keys {
		if v := os.Getenv(k); v != "" {
			return v
		}
	}
	if def != "" && orig == "" {
		return def
	}
	return orig
}

// intEnvOverride returns the integer value of the first environment variable in keys that is set
// to a non-empty value, exiting the process when that value is not a valid number.
// Otherwise it returns orig, or def when orig is zero.
func intEnvOverride(orig int, def int, keys ...string) int {
	for _, k := range keys {
		raw := os.Getenv(k)
		if raw == "" {
			continue
		}
		v, err := strconv.Atoi(raw)
		if err != nil {
			fmt.Fprintln(os.Stderr, k, "is not a valid number")
			os.Exit(1)
		}
		return v
	}
	if def != 0 && orig == 0 {
		return def
	}
	return orig
}
{{ end }}

// NewServer creates a new api {{ humanize .Name }} server but does not configure it
func NewServer(api *{{ .APIPackageAlias }}.{{ pascalize .Name }}API) *Server {
	s := new(Server)
	{{ if not .UseGoStructFlags }}
	s.EnabledListeners = enabledListeners
	s.CleanupTimeout = cleanupTimeout
	s.GracefulTimeout = gracefulTimeout
	s.MaxHeaderSize = maxHeaderSize
	s.SocketPath = socketPath
	s.Host = stringEnvOverride(host, "", "HOST")
	s.Port = intEnvOverride(port, 0, "PORT")
	s.ListenLimit = listenLimit
	s.KeepAlive = keepAlive
	s.ReadTimeout = readTimeout
	s.WriteTimeout = writeTimeout
	s.TLSHost = stringEnvOverride(tlsHost, s.Host, "TLS_HOST", "HOST")
	s.TLSPort = intEnvOverride(tlsPort, 0, "TLS_PORT")
	s.TLSCertificate = stringEnvOverride(tlsCertificate, "", "TLS_CERTIFICATE")
	s.TLSCertificateKey = stringEnvOverride(tlsCertificateKey, "", "TLS_PRIVATE_KEY")
	s.TLSCACertificate = stringEnvOverride(tlsCACertificate, "", "TLS_CA_CERTIFICATE")
	s.TLSListenLimit = tlsListenLimit
	s.TLSKeepAlive = tlsKeepAlive
	s.TLSReadTimeout = tlsReadTimeout
	s.TLSWriteTimeout = tlsWriteTimeout
	{{- if .ExcludeSpec }}
	s.Spec = specFile
	{{- end }}
	{{- end }}
	s.shutdown = make(chan struct{})
	s.api = api
	s.interrupt = make(chan os.Signal, 1)
	return s
}

// ConfigureAPI configures the API and handlers.
func (s *Server) ConfigureAPI() {
	if s.api != nil {
		s.handler = configureAPI(s.api)
	}
}

// ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
func (s *Server) ConfigureFlags() {
	if s.api != nil {
		configureFlags(s.api)
	}
}

// Server for the {{ humanize .Name }} API
type Server struct {
	EnabledListeners []string{{ if .UseGoStructFlags }} `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`{{ end }}
	CleanupTimeout   time.Duration{{ if .UseGoStructFlags }} `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`{{ end }}
	GracefulTimeout  time.Duration{{ if .UseGoStructFlags }} `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`{{ end }}
	MaxHeaderSize    flagext.ByteSize{{ if .UseGoStructFlags }} `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`{{ end }}

	SocketPath    {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/{{ dasherize .Name }}.sock"`{{ end }}
	domainSocketL net.Listener

	Host         string{{ if .UseGoStructFlags }} `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`{{ end }}
	Port         int{{ if .UseGoStructFlags }} `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`{{ end }}
	ListenLimit  int{{ if .UseGoStructFlags }} `long:"listen-limit" description:"limit the number of outstanding requests"`{{ end }}
	KeepAlive    time.Duration{{ if .UseGoStructFlags }} `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`{{ end }}
	ReadTimeout  time.Duration{{ if .UseGoStructFlags }} `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`{{ end }}
	WriteTimeout time.Duration{{ if .UseGoStructFlags }} `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`{{ end }}
	httpServerL  net.Listener

	TLSHost           string{{ if .UseGoStructFlags }} `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`{{ end }}
	TLSPort           int{{ if .UseGoStructFlags }} `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`{{ end }}
	TLSCertificate    {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`{{ end }}
	TLSCertificateKey {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`{{ end }}
	TLSCACertificate  {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`{{ end }}
	TLSListenLimit    int{{ if .UseGoStructFlags }} `long:"tls-listen-limit" description:"limit the number of outstanding requests"`{{ end }}
	TLSKeepAlive      time.Duration{{ if .UseGoStructFlags }} `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`{{ end }}
	TLSReadTimeout    time.Duration{{ if .UseGoStructFlags }} `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`{{ end }}
	TLSWriteTimeout   time.Duration{{ if .UseGoStructFlags }} `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`{{ end }}
	httpsServerL      net.Listener

	{{ if .ExcludeSpec }}Spec {{ if not .UseGoStructFlags }}string{{ else }}flags.Filename `long:"spec" description:"the swagger specification to serve"`{{ end }}{{ end }}
	api          *{{ .APIPackageAlias }}.{{ pascalize .Name }}API
	handler      http.Handler
	hasListeners bool
	shutdown     chan struct{}
	shuttingDown int32
	interrupted  bool
	interrupt    chan os.Signal
}

// Logf logs message either via defined user logger or via system one if no user logger is defined.
func (s *Server) Logf(f string, args ...interface{}) {
	if s.api != nil && s.api.Logger != nil {
		s.api.Logger(f, args...)
	} else {
		log.Printf(f, args...)
	}
}

// Fatalf logs message either via defined user logger or via system one if no user logger is defined.
// Exits with non-zero status after printing
func (s *Server) Fatalf(f string, args ...interface{}) {
	if s.api != nil && s.api.Logger != nil {
		s.api.Logger(f, args...)
		os.Exit(1)
	} else {
		log.Fatalf(f, args...)
	}
}

// SetAPI configures the server with the specified API. Needs to be called before Serve
func (s *Server) SetAPI(api *{{ .APIPackageAlias }}.{{ pascalize .Name }}API) {
	if api == nil {
		s.api = nil
		s.handler = nil
		return
	}

	s.api = api
	s.handler = configureAPI(api)
}

// hasScheme reports whether the listener for scheme is enabled, falling back to
// defaultSchemes when no listener was explicitly enabled.
func (s *Server) hasScheme(scheme string) bool {
	schemes := s.EnabledListeners
	if len(schemes) == 0 {
		schemes = defaultSchemes
	}

	for _, v := range schemes {
		if v == scheme {
			return true
		}
	}
	return false
}

// Serve the api.
//
// It creates the listeners when Listen was not called beforehand, starts one http.Server per
// enabled scheme and blocks until Shutdown has been requested and the shutdown sequence ran.
func (s *Server) Serve() (err error) {
	if !s.hasListeners {
		if err = s.Listen(); err != nil {
			return err
		}
	}

	// set default handler, if none is set
	if s.handler == nil {
		if s.api == nil {
			return errors.New("can't create the default handler, as no api is set")
		}

		s.SetHandler(s.api.Serve(nil))
	}

	wg := new(sync.WaitGroup)
	once := new(sync.Once)
	signalNotify(s.interrupt)
	go handleInterrupt(once, s)

	servers := []*http.Server{}

	if s.hasScheme(schemeUnix) {
		domainSocket := new(http.Server)
		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
		domainSocket.Handler = s.handler
		if int64(s.CleanupTimeout) > 0 {
			domainSocket.IdleTimeout = s.CleanupTimeout
		}

		configureServer(domainSocket, "unix", string(s.SocketPath))

		servers = append(servers, domainSocket)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at unix://%s", s.SocketPath)
		go func(l net.Listener) {
			defer wg.Done()
			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at unix://%s", s.SocketPath)
		}(s.domainSocketL)
	}

	if s.hasScheme(schemeHTTP) {
		httpServer := new(http.Server)
		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpServer.ReadTimeout = s.ReadTimeout
		httpServer.WriteTimeout = s.WriteTimeout
		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
		if s.ListenLimit > 0 {
			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
		}

		if int64(s.CleanupTimeout) > 0 {
			httpServer.IdleTimeout = s.CleanupTimeout
		}

		httpServer.Handler = s.handler

		configureServer(httpServer, "http", s.httpServerL.Addr().String())

		servers = append(servers, httpServer)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at http://%s", s.httpServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at http://%s", l.Addr())
		}(s.httpServerL)
	}

	if s.hasScheme(schemeHTTPS) {
		httpsServer := new(http.Server)
		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpsServer.ReadTimeout = s.TLSReadTimeout
		httpsServer.WriteTimeout = s.TLSWriteTimeout
		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
		if s.TLSListenLimit > 0 {
			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
		}
		if int64(s.CleanupTimeout) > 0 {
			httpsServer.IdleTimeout = s.CleanupTimeout
		}
		httpsServer.Handler = s.handler

		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
		httpsServer.TLSConfig = &tls.Config{
			// Asks servers to prefer their own cipher suite order (Go's defaults unless
			// CipherSuites is set) over the client's. Recent Go releases ignore this legacy
			// field and always pick the order themselves. Does nothing on clients.
			PreferServerCipherSuites: true,
			// Only use curves which have assembly implementations
			// https://github.com/golang/go/tree/master/src/crypto/elliptic
			CurvePreferences: []tls.CurveID{tls.CurveP256},
			{{- if .UseModernMode }}
			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
			NextProtos: []string{"h2", "http/1.1"},
			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
			MinVersion: tls.VersionTLS12,
			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
			CipherSuites: []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			},
			{{- end }}
		}

		// build standard config from server options
		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair({{ if .UseGoStructFlags }}string({{ end }}s.TLSCertificate{{ if .UseGoStructFlags }}){{ end }}, {{ if .UseGoStructFlags }}string({{ end }}s.TLSCertificateKey{{ if .UseGoStructFlags }}){{ end }})
			if err != nil {
				return err
			}
		}

		if s.TLSCACertificate != "" {
			// include specified CA certificate: clients must then present a certificate signed by it
			caCert, caCertErr := os.ReadFile({{ if .UseGoStructFlags }}string({{ end }}s.TLSCACertificate{{ if .UseGoStructFlags }}){{ end }})
			if caCertErr != nil {
				return caCertErr
			}
			caCertPool := x509.NewCertPool()
			ok := caCertPool.AppendCertsFromPEM(caCert)
			if !ok {
				return fmt.Errorf("cannot parse CA certificate")
			}
			httpsServer.TLSConfig.ClientCAs = caCertPool
			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
		}

		// call custom TLS configurator
		configureTLS(httpsServer.TLSConfig)

		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
			// after standard and custom config are passed, this ends up with no certificate
			if s.TLSCertificate == "" {
				if s.TLSCertificateKey == "" {
					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
				}
				s.Fatalf("the required flag `--tls-certificate` was not specified")
			}
			if s.TLSCertificateKey == "" {
				s.Fatalf("the required flag `--tls-key` was not specified")
			}
			// this happens with a wrong custom TLS configurator
			s.Fatalf("no certificate was configured for TLS")
		}

		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())

		servers = append(servers, httpsServer)
		wg.Add(1)
		s.Logf("Serving {{ humanize .Name }} at https://%s", s.httpsServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving {{ humanize .Name }} at https://%s", l.Addr())
		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
	}

	wg.Add(1)
	go s.handleShutdown(wg, &servers)

	wg.Wait()
	return nil
}

// Listen creates the listeners for the server.
//
// Once all listeners exist, further calls are no-ops. When one listener cannot be opened,
// the listeners already opened by this call are closed again (no socket nor unix socket file
// is leaked) and the server is left without listeners, so that Listen may be retried.
func (s *Server) Listen() (err error) {
	if s.hasListeners { // already done this
		return nil
	}

	if s.hasScheme(schemeHTTPS) {
		// Use http host if https host wasn't defined
		if s.TLSHost == "" {
			s.TLSHost = s.Host
		}
		// Use http listen limit if https listen limit wasn't defined
		if s.TLSListenLimit == 0 {
			s.TLSListenLimit = s.ListenLimit
		}
		// Use http tcp keep alive if https tcp keep alive wasn't defined
		if int64(s.TLSKeepAlive) == 0 {
			s.TLSKeepAlive = s.KeepAlive
		}
		// Use http read timeout if https read timeout wasn't defined
		if int64(s.TLSReadTimeout) == 0 {
			s.TLSReadTimeout = s.ReadTimeout
		}
		// Use http write timeout if https write timeout wasn't defined
		if int64(s.TLSWriteTimeout) == 0 {
			s.TLSWriteTimeout = s.WriteTimeout
		}
	}

	// listeners opened so far by this call: released again if a later step fails
	var opened []net.Listener
	defer func() {
		if err != nil {
			for _, l := range opened {
				// best effort: the original error is the one worth reporting
				_ = l.Close()
			}
		}
	}()

	var domainSocketL, httpServerL, httpsServerL net.Listener
	host, port := s.Host, s.Port
	tlsHost, tlsPort := s.TLSHost, s.TLSPort

	if s.hasScheme(schemeUnix) {
		domainSocketL, err = net.Listen("unix", string(s.SocketPath))
		if err != nil {
			return err
		}
		opened = append(opened, domainSocketL)
	}

	if s.hasScheme(schemeHTTP) {
		httpServerL, err = net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
		if err != nil {
			return err
		}
		opened = append(opened, httpServerL)

		// record the actual address, e.g. the port picked by the OS when none was given
		host, port, err = swag.SplitHostPort(httpServerL.Addr().String())
		if err != nil {
			return err
		}
	}

	if s.hasScheme(schemeHTTPS) {
		httpsServerL, err = net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
		if err != nil {
			return err
		}
		opened = append(opened, httpsServerL)

		tlsHost, tlsPort, err = swag.SplitHostPort(httpsServerL.Addr().String())
		if err != nil {
			return err
		}
	}

	// every listener is up: commit the state
	s.domainSocketL = domainSocketL
	s.httpServerL = httpServerL
	s.httpsServerL = httpsServerL
	s.Host, s.Port = host, port
	s.TLSHost, s.TLSPort = tlsHost, tlsPort
	s.hasListeners = true
	return nil
}

// Shutdown server and clean up resources.
// It only signals the shutdown (at most once) and returns immediately; Serve performs the
// actual shutdown sequence.
func (s *Server) Shutdown() error {
	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
		close(s.shutdown)
	}
	return nil
}

// handleShutdown waits for the shutdown signal, then gracefully stops every server (bounded by
// GracefulTimeout) and runs the api shutdown hooks.
func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
	// wg.Done must occur last, after s.api.ServerShutdown()
	// (to preserve old behaviour)
	defer wg.Done()

	<-s.shutdown

	servers := *serversPtr

	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
	defer cancel()

	// first execute the pre-shutdown hook
	// (there is no api when only a handler was provided through SetHandler)
	if s.api != nil {
		s.api.PreServerShutdown()
	}

	shutdownChan := make(chan bool)
	for i := range servers {
		server := servers[i]
		go func() {
			var success bool
			defer func() {
				shutdownChan <- success
			}()
			if err := server.Shutdown(ctx); err != nil {
				// Error from closing listeners, or context timeout:
				s.Logf("HTTP server Shutdown: %v", err)
			} else {
				success = true
			}
		}()
	}

	// Wait until all listeners have successfully shut down before calling ServerShutdown
	success := true
	for range servers {
		success = success && <-shutdownChan
	}
	if success && s.api != nil {
		s.api.ServerShutdown()
	}
}

// GetHandler returns a handler useful for testing
func (s *Server) GetHandler() http.Handler {
	return s.handler
}

// SetHandler allows for setting a http handler on this server
func (s *Server) SetHandler(handler http.Handler) {
	s.handler = handler
}

// UnixListener returns the domain socket listener
func (s *Server) UnixListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.domainSocketL, nil
}

// HTTPListener returns the http listener
func (s *Server) HTTPListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.httpServerL, nil
}

// TLSListener returns the https listener
func (s *Server) TLSListener() (net.Listener, error) {
	if !s.hasListeners {
		if err := s.Listen(); err != nil {
			return nil, err
		}
	}
	return s.httpsServerL, nil
}

// handleInterrupt triggers Shutdown on the first received interrupt signal and only logs
// the following ones.
func handleInterrupt(once *sync.Once, s *Server) {
	once.Do(func() {
		for range s.interrupt {
			if s.interrupted {
				s.Logf("Server already shutting down")
				continue
			}
			s.interrupted = true
			s.Logf("Shutting down... ")
			if err := s.Shutdown(); err != nil {
				s.Logf("HTTP server Shutdown: %v", err)
			}
		}
	})
}

// signalNotify relays SIGINT and SIGTERM to interrupt.
func signalNotify(interrupt chan<- os.Signal) {
	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
}