github.com/emreu/go-swagger@v0.22.1/examples/stream-server/restapi/server.go (about)

     1  // Code generated by go-swagger; DO NOT EDIT.
     2  
     3  package restapi
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"os/signal"
    17  	"strconv"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  
    23  	"github.com/go-openapi/runtime/flagext"
    24  	"github.com/go-openapi/swag"
    25  	flags "github.com/jessevdk/go-flags"
    26  	"golang.org/x/net/netutil"
    27  
    28  	"github.com/go-swagger/go-swagger/examples/stream-server/restapi/operations"
    29  )
    30  
    31  const (
    32  	schemeHTTP  = "http"
    33  	schemeHTTPS = "https"
    34  	schemeUnix  = "unix"
    35  )
    36  
    37  var defaultSchemes []string
    38  
    39  func init() {
    40  	defaultSchemes = []string{
    41  		schemeHTTP,
    42  	}
    43  }
    44  
    45  // NewServer creates a new api countdown server but does not configure it
    46  func NewServer(api *operations.CountdownAPI) *Server {
    47  	s := new(Server)
    48  
    49  	s.shutdown = make(chan struct{})
    50  	s.api = api
    51  	s.interrupt = make(chan os.Signal, 1)
    52  	return s
    53  }
    54  
    55  // ConfigureAPI configures the API and handlers.
    56  func (s *Server) ConfigureAPI() {
    57  	if s.api != nil {
    58  		s.handler = configureAPI(s.api)
    59  	}
    60  }
    61  
    62  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    63  func (s *Server) ConfigureFlags() {
    64  	if s.api != nil {
    65  		configureFlags(s.api)
    66  	}
    67  }
    68  
// Server for the countdown API.
//
// Exported fields are command-line options, described by their go-flags struct
// tags. Unexported fields hold runtime state managed by Listen, Serve and Shutdown.
// Obtain a Server from NewServer, which allocates the shutdown and interrupt
// channels used by Serve and Shutdown.
type Server struct {
	// EnabledListeners picks which of "http", "https", "unix" to serve; when empty
	// hasScheme falls back to defaultSchemes.
	// CleanupTimeout, when positive, becomes http.Server.IdleTimeout for every server.
	// GracefulTimeout bounds how long handleShutdown waits for servers to drain.
	// MaxHeaderSize is copied into http.Server.MaxHeaderBytes for every server.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket options; domainSocketL is opened by Listen.
	SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/countdown.sock"`
	domainSocketL net.Listener

	// Plain HTTP options; httpServerL is opened by Listen, which writes the resolved
	// host and port back into Host and Port (useful when Port is 0).
	// ListenLimit, when positive, makes Serve wrap the listener in netutil.LimitListener.
	// NOTE(review): despite its description, KeepAlive is only used in this file to
	// switch HTTP keep-alives on or off (SetKeepAlivesEnabled in Serve); the duration
	// itself is never applied to the TCP listener.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	httpServerL  net.Listener

	// TLS options; httpsServerL is opened by Listen. When https is enabled, Listen
	// copies TLSHost, TLSListenLimit, TLSKeepAlive, TLSReadTimeout and
	// TLSWriteTimeout from their plain-HTTP counterparts if left at the zero value.
	// Setting TLSCACertificate makes Serve require and verify client certificates.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	httpsServerL      net.Listener

	api          *operations.CountdownAPI // API whose handler and hooks are served
	handler      http.Handler             // installed as Handler on every http.Server Serve starts
	hasListeners bool                     // set once Listen has opened the listeners
	shutdown     chan struct{}            // closed by Shutdown to start the graceful stop
	shuttingDown int32                    // accessed atomically; guards the single close of shutdown
	interrupted  bool                     // set by handleInterrupt after the first signal
	interrupt    chan os.Signal           // receives SIGINT/SIGTERM once signalNotify is called
}
   106  
   107  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   108  func (s *Server) Logf(f string, args ...interface{}) {
   109  	if s.api != nil && s.api.Logger != nil {
   110  		s.api.Logger(f, args...)
   111  	} else {
   112  		log.Printf(f, args...)
   113  	}
   114  }
   115  
   116  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   117  // Exits with non-zero status after printing
   118  func (s *Server) Fatalf(f string, args ...interface{}) {
   119  	if s.api != nil && s.api.Logger != nil {
   120  		s.api.Logger(f, args...)
   121  		os.Exit(1)
   122  	} else {
   123  		log.Fatalf(f, args...)
   124  	}
   125  }
   126  
   127  // SetAPI configures the server with the specified API. Needs to be called before Serve
   128  func (s *Server) SetAPI(api *operations.CountdownAPI) {
   129  	if api == nil {
   130  		s.api = nil
   131  		s.handler = nil
   132  		return
   133  	}
   134  
   135  	s.api = api
   136  	s.handler = configureAPI(api)
   137  }
   138  
   139  func (s *Server) hasScheme(scheme string) bool {
   140  	schemes := s.EnabledListeners
   141  	if len(schemes) == 0 {
   142  		schemes = defaultSchemes
   143  	}
   144  
   145  	for _, v := range schemes {
   146  		if v == scheme {
   147  			return true
   148  		}
   149  	}
   150  	return false
   151  }
   152  
   153  // Serve the api
   154  func (s *Server) Serve() (err error) {
   155  	if !s.hasListeners {
   156  		if err = s.Listen(); err != nil {
   157  			return err
   158  		}
   159  	}
   160  
   161  	// set default handler, if none is set
   162  	if s.handler == nil {
   163  		if s.api == nil {
   164  			return errors.New("can't create the default handler, as no api is set")
   165  		}
   166  
   167  		s.SetHandler(s.api.Serve(nil))
   168  	}
   169  
   170  	wg := new(sync.WaitGroup)
   171  	once := new(sync.Once)
   172  	signalNotify(s.interrupt)
   173  	go handleInterrupt(once, s)
   174  
   175  	servers := []*http.Server{}
   176  	wg.Add(1)
   177  	go s.handleShutdown(wg, &servers)
   178  
   179  	if s.hasScheme(schemeUnix) {
   180  		domainSocket := new(http.Server)
   181  		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
   182  		domainSocket.Handler = s.handler
   183  		if int64(s.CleanupTimeout) > 0 {
   184  			domainSocket.IdleTimeout = s.CleanupTimeout
   185  		}
   186  
   187  		configureServer(domainSocket, "unix", string(s.SocketPath))
   188  
   189  		servers = append(servers, domainSocket)
   190  		wg.Add(1)
   191  		s.Logf("Serving countdown at unix://%s", s.SocketPath)
   192  		go func(l net.Listener) {
   193  			defer wg.Done()
   194  			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
   195  				s.Fatalf("%v", err)
   196  			}
   197  			s.Logf("Stopped serving countdown at unix://%s", s.SocketPath)
   198  		}(s.domainSocketL)
   199  	}
   200  
   201  	if s.hasScheme(schemeHTTP) {
   202  		httpServer := new(http.Server)
   203  		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   204  		httpServer.ReadTimeout = s.ReadTimeout
   205  		httpServer.WriteTimeout = s.WriteTimeout
   206  		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
   207  		if s.ListenLimit > 0 {
   208  			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
   209  		}
   210  
   211  		if int64(s.CleanupTimeout) > 0 {
   212  			httpServer.IdleTimeout = s.CleanupTimeout
   213  		}
   214  
   215  		httpServer.Handler = s.handler
   216  
   217  		configureServer(httpServer, "http", s.httpServerL.Addr().String())
   218  
   219  		servers = append(servers, httpServer)
   220  		wg.Add(1)
   221  		s.Logf("Serving countdown at http://%s", s.httpServerL.Addr())
   222  		go func(l net.Listener) {
   223  			defer wg.Done()
   224  			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
   225  				s.Fatalf("%v", err)
   226  			}
   227  			s.Logf("Stopped serving countdown at http://%s", l.Addr())
   228  		}(s.httpServerL)
   229  	}
   230  
   231  	if s.hasScheme(schemeHTTPS) {
   232  		httpsServer := new(http.Server)
   233  		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   234  		httpsServer.ReadTimeout = s.TLSReadTimeout
   235  		httpsServer.WriteTimeout = s.TLSWriteTimeout
   236  		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
   237  		if s.TLSListenLimit > 0 {
   238  			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
   239  		}
   240  		if int64(s.CleanupTimeout) > 0 {
   241  			httpsServer.IdleTimeout = s.CleanupTimeout
   242  		}
   243  		httpsServer.Handler = s.handler
   244  
   245  		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
   246  		httpsServer.TLSConfig = &tls.Config{
   247  			// Causes servers to use Go's default ciphersuite preferences,
   248  			// which are tuned to avoid attacks. Does nothing on clients.
   249  			PreferServerCipherSuites: true,
   250  			// Only use curves which have assembly implementations
   251  			// https://github.com/golang/go/tree/master/src/crypto/elliptic
   252  			CurvePreferences: []tls.CurveID{tls.CurveP256},
   253  			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
   254  			NextProtos: []string{"h2", "http/1.1"},
   255  			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
   256  			MinVersion: tls.VersionTLS12,
   257  			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
   258  			CipherSuites: []uint16{
   259  				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   260  				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
   261  				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   262  				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   263  				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
   264  				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
   265  			},
   266  		}
   267  
   268  		// build standard config from server options
   269  		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
   270  			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
   271  			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
   272  			if err != nil {
   273  				return err
   274  			}
   275  		}
   276  
   277  		if s.TLSCACertificate != "" {
   278  			// include specified CA certificate
   279  			caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
   280  			if caCertErr != nil {
   281  				return caCertErr
   282  			}
   283  			caCertPool := x509.NewCertPool()
   284  			ok := caCertPool.AppendCertsFromPEM(caCert)
   285  			if !ok {
   286  				return fmt.Errorf("cannot parse CA certificate")
   287  			}
   288  			httpsServer.TLSConfig.ClientCAs = caCertPool
   289  			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
   290  		}
   291  
   292  		// call custom TLS configurator
   293  		configureTLS(httpsServer.TLSConfig)
   294  
   295  		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
   296  			// after standard and custom config are passed, this ends up with no certificate
   297  			if s.TLSCertificate == "" {
   298  				if s.TLSCertificateKey == "" {
   299  					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
   300  				}
   301  				s.Fatalf("the required flag `--tls-certificate` was not specified")
   302  			}
   303  			if s.TLSCertificateKey == "" {
   304  				s.Fatalf("the required flag `--tls-key` was not specified")
   305  			}
   306  			// this happens with a wrong custom TLS configurator
   307  			s.Fatalf("no certificate was configured for TLS")
   308  		}
   309  
   310  		// must have at least one certificate or panics
   311  		httpsServer.TLSConfig.BuildNameToCertificate()
   312  
   313  		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
   314  
   315  		servers = append(servers, httpsServer)
   316  		wg.Add(1)
   317  		s.Logf("Serving countdown at https://%s", s.httpsServerL.Addr())
   318  		go func(l net.Listener) {
   319  			defer wg.Done()
   320  			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
   321  				s.Fatalf("%v", err)
   322  			}
   323  			s.Logf("Stopped serving countdown at https://%s", l.Addr())
   324  		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
   325  	}
   326  
   327  	wg.Wait()
   328  	return nil
   329  }
   330  
   331  // Listen creates the listeners for the server
   332  func (s *Server) Listen() error {
   333  	if s.hasListeners { // already done this
   334  		return nil
   335  	}
   336  
   337  	if s.hasScheme(schemeHTTPS) {
   338  		// Use http host if https host wasn't defined
   339  		if s.TLSHost == "" {
   340  			s.TLSHost = s.Host
   341  		}
   342  		// Use http listen limit if https listen limit wasn't defined
   343  		if s.TLSListenLimit == 0 {
   344  			s.TLSListenLimit = s.ListenLimit
   345  		}
   346  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   347  		if int64(s.TLSKeepAlive) == 0 {
   348  			s.TLSKeepAlive = s.KeepAlive
   349  		}
   350  		// Use http read timeout if https read timeout wasn't defined
   351  		if int64(s.TLSReadTimeout) == 0 {
   352  			s.TLSReadTimeout = s.ReadTimeout
   353  		}
   354  		// Use http write timeout if https write timeout wasn't defined
   355  		if int64(s.TLSWriteTimeout) == 0 {
   356  			s.TLSWriteTimeout = s.WriteTimeout
   357  		}
   358  	}
   359  
   360  	if s.hasScheme(schemeUnix) {
   361  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   362  		if err != nil {
   363  			return err
   364  		}
   365  		s.domainSocketL = domSockListener
   366  	}
   367  
   368  	if s.hasScheme(schemeHTTP) {
   369  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   370  		if err != nil {
   371  			return err
   372  		}
   373  
   374  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   375  		if err != nil {
   376  			return err
   377  		}
   378  		s.Host = h
   379  		s.Port = p
   380  		s.httpServerL = listener
   381  	}
   382  
   383  	if s.hasScheme(schemeHTTPS) {
   384  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   385  		if err != nil {
   386  			return err
   387  		}
   388  
   389  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   390  		if err != nil {
   391  			return err
   392  		}
   393  		s.TLSHost = sh
   394  		s.TLSPort = sp
   395  		s.httpsServerL = tlsListener
   396  	}
   397  
   398  	s.hasListeners = true
   399  	return nil
   400  }
   401  
   402  // Shutdown server and clean up resources
   403  func (s *Server) Shutdown() error {
   404  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   405  		close(s.shutdown)
   406  	}
   407  	return nil
   408  }
   409  
   410  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   411  	// wg.Done must occur last, after s.api.ServerShutdown()
   412  	// (to preserve old behaviour)
   413  	defer wg.Done()
   414  
   415  	<-s.shutdown
   416  
   417  	servers := *serversPtr
   418  
   419  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   420  	defer cancel()
   421  
   422  	shutdownChan := make(chan bool)
   423  	for i := range servers {
   424  		server := servers[i]
   425  		go func() {
   426  			var success bool
   427  			defer func() {
   428  				shutdownChan <- success
   429  			}()
   430  			if err := server.Shutdown(ctx); err != nil {
   431  				// Error from closing listeners, or context timeout:
   432  				s.Logf("HTTP server Shutdown: %v", err)
   433  			} else {
   434  				success = true
   435  			}
   436  		}()
   437  	}
   438  
   439  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   440  	success := true
   441  	for range servers {
   442  		success = success && <-shutdownChan
   443  	}
   444  	if success {
   445  		s.api.ServerShutdown()
   446  	}
   447  }
   448  
   449  // GetHandler returns a handler useful for testing
   450  func (s *Server) GetHandler() http.Handler {
   451  	return s.handler
   452  }
   453  
   454  // SetHandler allows for setting a http handler on this server
   455  func (s *Server) SetHandler(handler http.Handler) {
   456  	s.handler = handler
   457  }
   458  
   459  // UnixListener returns the domain socket listener
   460  func (s *Server) UnixListener() (net.Listener, error) {
   461  	if !s.hasListeners {
   462  		if err := s.Listen(); err != nil {
   463  			return nil, err
   464  		}
   465  	}
   466  	return s.domainSocketL, nil
   467  }
   468  
   469  // HTTPListener returns the http listener
   470  func (s *Server) HTTPListener() (net.Listener, error) {
   471  	if !s.hasListeners {
   472  		if err := s.Listen(); err != nil {
   473  			return nil, err
   474  		}
   475  	}
   476  	return s.httpServerL, nil
   477  }
   478  
   479  // TLSListener returns the https listener
   480  func (s *Server) TLSListener() (net.Listener, error) {
   481  	if !s.hasListeners {
   482  		if err := s.Listen(); err != nil {
   483  			return nil, err
   484  		}
   485  	}
   486  	return s.httpsServerL, nil
   487  }
   488  
   489  func handleInterrupt(once *sync.Once, s *Server) {
   490  	once.Do(func() {
   491  		for _ = range s.interrupt {
   492  			if s.interrupted {
   493  				s.Logf("Server already shutting down")
   494  				continue
   495  			}
   496  			s.interrupted = true
   497  			s.Logf("Shutting down... ")
   498  			if err := s.Shutdown(); err != nil {
   499  				s.Logf("HTTP server Shutdown: %v", err)
   500  			}
   501  		}
   502  	})
   503  }
   504  
   505  func signalNotify(interrupt chan<- os.Signal) {
   506  	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
   507  }