github.com/djarvur/go-swagger@v0.18.0/examples/oauth2/restapi/server.go (about) 1 // Code generated by go-swagger; DO NOT EDIT. 2 3 package restapi 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "errors" 10 "fmt" 11 "io/ioutil" 12 "log" 13 "net" 14 "net/http" 15 "os" 16 "os/signal" 17 "strconv" 18 "sync" 19 "sync/atomic" 20 "syscall" 21 "time" 22 23 "github.com/go-openapi/runtime/flagext" 24 "github.com/go-openapi/swag" 25 flags "github.com/jessevdk/go-flags" 26 "golang.org/x/net/netutil" 27 28 "github.com/go-swagger/go-swagger/examples/oauth2/restapi/operations" 29 ) 30 31 const ( 32 schemeHTTP = "http" 33 schemeHTTPS = "https" 34 schemeUnix = "unix" 35 ) 36 37 var defaultSchemes []string 38 39 func init() { 40 defaultSchemes = []string{ 41 schemeHTTP, 42 } 43 } 44 45 // NewServer creates a new api oauth sample server but does not configure it 46 func NewServer(api *operations.OauthSampleAPI) *Server { 47 s := new(Server) 48 49 s.shutdown = make(chan struct{}) 50 s.api = api 51 s.interrupt = make(chan os.Signal, 1) 52 return s 53 } 54 55 // ConfigureAPI configures the API and handlers. 56 func (s *Server) ConfigureAPI() { 57 if s.api != nil { 58 s.handler = configureAPI(s.api) 59 } 60 } 61 62 // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse 63 func (s *Server) ConfigureFlags() { 64 if s.api != nil { 65 configureFlags(s.api) 66 } 67 } 68 69 // Server for the oauth sample API 70 type Server struct { 71 EnabledListeners []string `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"` 72 CleanupTimeout time.Duration `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"` 73 GracefulTimeout time.Duration `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"` 74 MaxHeaderSize flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"` 75 76 SocketPath flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/oauth-sample.sock"` 77 domainSocketL net.Listener 78 79 Host string `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"` 80 Port int `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"` 81 ListenLimit int `long:"listen-limit" description:"limit the number of outstanding requests"` 82 KeepAlive time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"` 83 ReadTimeout time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"` 84 WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"` 85 httpServerL net.Listener 86 87 TLSHost string `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"` 88 TLSPort int `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"` 89 TLSCertificate flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"` 90 TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure conections" env:"TLS_PRIVATE_KEY"` 91 TLSCACertificate flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"` 92 TLSListenLimit int `long:"tls-listen-limit" description:"limit the number of outstanding requests"` 93 TLSKeepAlive time.Duration `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"` 94 TLSReadTimeout time.Duration `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"` 95 TLSWriteTimeout time.Duration `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"` 96 httpsServerL net.Listener 97 98 api *operations.OauthSampleAPI 99 handler http.Handler 100 hasListeners bool 101 shutdown chan struct{} 102 shuttingDown int32 103 interrupted bool 104 interrupt chan os.Signal 105 } 106 107 // Logf logs message either via defined user logger or via system one if no user logger is defined. 108 func (s *Server) Logf(f string, args ...interface{}) { 109 if s.api != nil && s.api.Logger != nil { 110 s.api.Logger(f, args...) 111 } else { 112 log.Printf(f, args...) 113 } 114 } 115 116 // Fatalf logs message either via defined user logger or via system one if no user logger is defined. 117 // Exits with non-zero status after printing 118 func (s *Server) Fatalf(f string, args ...interface{}) { 119 if s.api != nil && s.api.Logger != nil { 120 s.api.Logger(f, args...) 121 os.Exit(1) 122 } else { 123 log.Fatalf(f, args...) 124 } 125 } 126 127 // SetAPI configures the server with the specified API. Needs to be called before Serve 128 func (s *Server) SetAPI(api *operations.OauthSampleAPI) { 129 if api == nil { 130 s.api = nil 131 s.handler = nil 132 return 133 } 134 135 s.api = api 136 s.api.Logger = log.Printf 137 s.handler = configureAPI(api) 138 } 139 140 func (s *Server) hasScheme(scheme string) bool { 141 schemes := s.EnabledListeners 142 if len(schemes) == 0 { 143 schemes = defaultSchemes 144 } 145 146 for _, v := range schemes { 147 if v == scheme { 148 return true 149 } 150 } 151 return false 152 } 153 154 // Serve the api 155 func (s *Server) Serve() (err error) { 156 if !s.hasListeners { 157 if err = s.Listen(); err != nil { 158 return err 159 } 160 } 161 162 // set default handler, if none is set 163 if s.handler == nil { 164 if s.api == nil { 165 return errors.New("can't create the default handler, as no api is set") 166 } 167 168 s.SetHandler(s.api.Serve(nil)) 169 } 170 171 wg := new(sync.WaitGroup) 172 once := new(sync.Once) 173 signalNotify(s.interrupt) 174 go handleInterrupt(once, s) 175 176 servers := []*http.Server{} 177 wg.Add(1) 178 go s.handleShutdown(wg, &servers) 179 180 if s.hasScheme(schemeUnix) { 181 domainSocket := new(http.Server) 182 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize) 183 domainSocket.Handler = s.handler 184 if int64(s.CleanupTimeout) > 0 { 185 domainSocket.IdleTimeout = s.CleanupTimeout 186 } 187 188 configureServer(domainSocket, "unix", string(s.SocketPath)) 189 190 servers = append(servers, domainSocket) 191 wg.Add(1) 192 s.Logf("Serving oauth sample at unix://%s", s.SocketPath) 193 go func(l net.Listener) { 194 defer wg.Done() 195 if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed { 196 s.Fatalf("%v", err) 197 } 198 s.Logf("Stopped serving oauth sample at unix://%s", s.SocketPath) 199 }(s.domainSocketL) 200 } 201 202 if s.hasScheme(schemeHTTP) { 203 httpServer := new(http.Server) 204 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize) 205 httpServer.ReadTimeout = s.ReadTimeout 206 httpServer.WriteTimeout = s.WriteTimeout 207 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0) 208 if s.ListenLimit > 0 { 209 s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit) 210 } 211 212 if int64(s.CleanupTimeout) > 0 { 213 httpServer.IdleTimeout = s.CleanupTimeout 214 } 215 216 httpServer.Handler = s.handler 217 218 configureServer(httpServer, "http", s.httpServerL.Addr().String()) 219 220 servers = append(servers, httpServer) 221 wg.Add(1) 222 s.Logf("Serving oauth sample at http://%s", s.httpServerL.Addr()) 223 go func(l net.Listener) { 224 defer wg.Done() 225 if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed { 226 s.Fatalf("%v", err) 227 } 228 s.Logf("Stopped serving oauth sample at http://%s", l.Addr()) 229 }(s.httpServerL) 230 } 231 232 if s.hasScheme(schemeHTTPS) { 233 httpsServer := new(http.Server) 234 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize) 235 httpsServer.ReadTimeout = s.TLSReadTimeout 236 httpsServer.WriteTimeout = s.TLSWriteTimeout 237 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0) 238 if s.TLSListenLimit > 0 { 239 s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit) 240 } 241 if int64(s.CleanupTimeout) > 0 { 242 httpsServer.IdleTimeout = s.CleanupTimeout 243 } 244 httpsServer.Handler = s.handler 245 246 // Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go 247 httpsServer.TLSConfig = &tls.Config{ 248 // Causes servers to use Go's default ciphersuite preferences, 249 // which are tuned to avoid attacks. Does nothing on clients. 250 PreferServerCipherSuites: true, 251 // Only use curves which have assembly implementations 252 // https://github.com/golang/go/tree/master/src/crypto/elliptic 253 CurvePreferences: []tls.CurveID{tls.CurveP256}, 254 // Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility 255 NextProtos: []string{"http/1.1", "h2"}, 256 // https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols 257 MinVersion: tls.VersionTLS12, 258 // These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy 259 CipherSuites: []uint16{ 260 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 261 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 262 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 263 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 264 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 265 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 266 }, 267 } 268 269 // build standard config from server options 270 if s.TLSCertificate != "" && s.TLSCertificateKey != "" { 271 httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1) 272 httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey)) 273 if err != nil { 274 return err 275 } 276 } 277 278 if s.TLSCACertificate != "" { 279 // include specified CA certificate 280 caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate)) 281 if caCertErr != nil { 282 return caCertErr 283 } 284 caCertPool := x509.NewCertPool() 285 ok := caCertPool.AppendCertsFromPEM(caCert) 286 if !ok { 287 return fmt.Errorf("cannot parse CA certificate") 288 } 289 httpsServer.TLSConfig.ClientCAs = caCertPool 290 httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert 291 } 292 293 // call custom TLS configurator 294 configureTLS(httpsServer.TLSConfig) 295 296 if len(httpsServer.TLSConfig.Certificates) == 0 { 297 // after standard and custom config are passed, this ends up with no certificate 298 if s.TLSCertificate == "" { 299 if s.TLSCertificateKey == "" { 300 s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified") 301 } 302 s.Fatalf("the required flag `--tls-certificate` was not specified") 303 } 304 if s.TLSCertificateKey == "" { 305 s.Fatalf("the required flag `--tls-key` was not specified") 306 } 307 // this happens with a wrong custom TLS configurator 308 s.Fatalf("no certificate was configured for TLS") 309 } 310 311 // must have at least one certificate or panics 312 httpsServer.TLSConfig.BuildNameToCertificate() 313 314 configureServer(httpsServer, "https", s.httpsServerL.Addr().String()) 315 316 servers = append(servers, httpsServer) 317 wg.Add(1) 318 s.Logf("Serving oauth sample at https://%s", s.httpsServerL.Addr()) 319 go func(l net.Listener) { 320 defer wg.Done() 321 if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed { 322 s.Fatalf("%v", err) 323 } 324 s.Logf("Stopped serving oauth sample at https://%s", l.Addr()) 325 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig)) 326 } 327 328 wg.Wait() 329 return nil 330 } 331 332 // Listen creates the listeners for the server 333 func (s *Server) Listen() error { 334 if s.hasListeners { // already done this 335 return nil 336 } 337 338 if s.hasScheme(schemeHTTPS) { 339 // Use http host if https host wasn't defined 340 if s.TLSHost == "" { 341 s.TLSHost = s.Host 342 } 343 // Use http listen limit if https listen limit wasn't defined 344 if s.TLSListenLimit == 0 { 345 s.TLSListenLimit = s.ListenLimit 346 } 347 // Use http tcp keep alive if https tcp keep alive wasn't defined 348 if int64(s.TLSKeepAlive) == 0 { 349 s.TLSKeepAlive = s.KeepAlive 350 } 351 // Use http read timeout if https read timeout wasn't defined 352 if int64(s.TLSReadTimeout) == 0 { 353 s.TLSReadTimeout = s.ReadTimeout 354 } 355 // Use http write timeout if https write timeout wasn't defined 356 if int64(s.TLSWriteTimeout) == 0 { 357 s.TLSWriteTimeout = s.WriteTimeout 358 } 359 } 360 361 if s.hasScheme(schemeUnix) { 362 domSockListener, err := net.Listen("unix", string(s.SocketPath)) 363 if err != nil { 364 return err 365 } 366 s.domainSocketL = domSockListener 367 } 368 369 if s.hasScheme(schemeHTTP) { 370 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port))) 371 if err != nil { 372 return err 373 } 374 375 h, p, err := swag.SplitHostPort(listener.Addr().String()) 376 if err != nil { 377 return err 378 } 379 s.Host = h 380 s.Port = p 381 s.httpServerL = listener 382 } 383 384 if s.hasScheme(schemeHTTPS) { 385 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort))) 386 if err != nil { 387 return err 388 } 389 390 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String()) 391 if err != nil { 392 return err 393 } 394 s.TLSHost = sh 395 s.TLSPort = sp 396 s.httpsServerL = tlsListener 397 } 398 399 s.hasListeners = true 400 return nil 401 } 402 403 // Shutdown server and clean up resources 404 func (s *Server) Shutdown() error { 405 if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) { 406 close(s.shutdown) 407 } 408 return nil 409 } 410 411 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) { 412 // wg.Done must occur last, after s.api.ServerShutdown() 413 // (to preserve old behaviour) 414 defer wg.Done() 415 416 <-s.shutdown 417 418 servers := *serversPtr 419 420 ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout) 421 defer cancel() 422 423 shutdownChan := make(chan bool) 424 for i := range servers { 425 server := servers[i] 426 go func() { 427 var success bool 428 defer func() { 429 shutdownChan <- success 430 }() 431 if err := server.Shutdown(ctx); err != nil { 432 // Error from closing listeners, or context timeout: 433 s.Logf("HTTP server Shutdown: %v", err) 434 } else { 435 success = true 436 } 437 }() 438 } 439 440 // Wait until all listeners have successfully shut down before calling ServerShutdown 441 success := true 442 for range servers { 443 success = success && <-shutdownChan 444 } 445 if success { 446 s.api.ServerShutdown() 447 } 448 } 449 450 // GetHandler returns a handler useful for testing 451 func (s *Server) GetHandler() http.Handler { 452 return s.handler 453 } 454 455 // SetHandler allows for setting a http handler on this server 456 func (s *Server) SetHandler(handler http.Handler) { 457 s.handler = handler 458 } 459 460 // UnixListener returns the domain socket listener 461 func (s *Server) UnixListener() (net.Listener, error) { 462 if !s.hasListeners { 463 if err := s.Listen(); err != nil { 464 return nil, err 465 } 466 } 467 return s.domainSocketL, nil 468 } 469 470 // HTTPListener returns the http listener 471 func (s *Server) HTTPListener() (net.Listener, error) { 472 if !s.hasListeners { 473 if err := s.Listen(); err != nil { 474 return nil, err 475 } 476 } 477 return s.httpServerL, nil 478 } 479 480 // TLSListener returns the https listener 481 func (s *Server) TLSListener() (net.Listener, error) { 482 if !s.hasListeners { 483 if err := s.Listen(); err != nil { 484 return nil, err 485 } 486 } 487 return s.httpsServerL, nil 488 } 489 490 func handleInterrupt(once *sync.Once, s *Server) { 491 once.Do(func() { 492 for _ = range s.interrupt { 493 if s.interrupted { 494 s.Logf("Server already shutting down") 495 continue 496 } 497 s.interrupted = true 498 s.Logf("Shutting down... ") 499 if err := s.Shutdown(); err != nil { 500 s.Logf("HTTP server Shutdown: %v", err) 501 } 502 } 503 }) 504 } 505 506 func signalNotify(interrupt chan<- os.Signal) { 507 signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) 508 }