github.com/circl-dev/go-swagger@v0.31.0/examples/todo-list-strict/restapi/server.go (about)

     1  // Code generated by go-swagger; DO NOT EDIT.
     2  
     3  package restapi
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"os/signal"
    17  	"strconv"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  
    23  	"github.com/go-openapi/swag"
    24  	flags "github.com/jessevdk/go-flags"
    25  	"github.com/circl-dev/runtime/flagext"
    26  	"golang.org/x/net/netutil"
    27  
    28  	"github.com/circl-dev/go-swagger/examples/todo-list-strict/restapi/operations"
    29  )
    30  
    31  const (
    32  	schemeHTTP  = "http"
    33  	schemeHTTPS = "https"
    34  	schemeUnix  = "unix"
    35  )
    36  
    37  var defaultSchemes []string
    38  
    39  func init() {
    40  	defaultSchemes = []string{
    41  		schemeHTTP,
    42  		schemeHTTPS,
    43  		schemeUnix,
    44  	}
    45  }
    46  
    47  // NewServer creates a new api todo list server but does not configure it
    48  func NewServer(api *operations.TodoListAPI) *Server {
    49  	s := new(Server)
    50  
    51  	s.shutdown = make(chan struct{})
    52  	s.api = api
    53  	s.interrupt = make(chan os.Signal, 1)
    54  	return s
    55  }
    56  
    57  // ConfigureAPI configures the API and handlers.
    58  func (s *Server) ConfigureAPI() {
    59  	if s.api != nil {
    60  		s.handler = configureAPI(s.api)
    61  	}
    62  }
    63  
    64  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    65  func (s *Server) ConfigureFlags() {
    66  	if s.api != nil {
    67  		configureFlags(s.api)
    68  	}
    69  }
    70  
// Server for the todo list API.
//
// Exported fields are command-line options (see the struct tags). Listen opens
// the listeners for the enabled schemes, Serve runs one http.Server per
// listener, and Shutdown (or SIGINT/SIGTERM) stops them again.
type Server struct {
	// General options shared by all listeners. CleanupTimeout becomes each
	// server's IdleTimeout, GracefulTimeout bounds handleShutdown, and
	// MaxHeaderSize becomes MaxHeaderBytes.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket listener ("unix" scheme).
	SocketPath flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/todo-list.sock"`
	// domainSocketL is opened by Listen on SocketPath.
	domainSocketL net.Listener

	// Plain HTTP listener ("http" scheme). Listen overwrites Host and Port
	// with the address actually bound.
	// NOTE(review): despite its description, KeepAlive is only used by Serve
	// as an on/off switch for HTTP keep-alives (> 0); no TCP keep-alive period
	// is applied in this file.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	// httpServerL is opened by Listen; Serve may wrap it with a listen limit.
	httpServerL net.Listener

	// HTTPS listener ("https" scheme). When left at their zero value, TLSHost,
	// TLSListenLimit, TLSKeepAlive, TLSReadTimeout and TLSWriteTimeout are
	// copied from the HTTP options by Listen. Supplying TLSCACertificate makes
	// Serve require and verify client certificates.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	// httpsServerL is the raw TCP listener opened by Listen; Serve wraps it
	// with tls.NewListener (and optionally a listen limit).
	httpsServerL net.Listener

	// api provides the handler, logger and shutdown hooks; it may be nil.
	api *operations.TodoListAPI
	// handler is served on every listener (see SetHandler/ConfigureAPI).
	handler http.Handler
	// hasListeners is set once Listen has opened all enabled listeners.
	hasListeners bool
	// shutdown is closed by Shutdown to wake up handleShutdown.
	shutdown chan struct{}
	// shuttingDown is flipped 0->1 atomically by Shutdown so that shutdown
	// is closed only once.
	shuttingDown int32
	// interrupted is set by handleInterrupt after the first signal; later
	// signals are only logged.
	interrupted bool
	// interrupt receives SIGINT/SIGTERM once Serve calls signalNotify
	// (NewServer allocates it with capacity 1).
	interrupt chan os.Signal
}
   108  
   109  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   110  func (s *Server) Logf(f string, args ...interface{}) {
   111  	if s.api != nil && s.api.Logger != nil {
   112  		s.api.Logger(f, args...)
   113  	} else {
   114  		log.Printf(f, args...)
   115  	}
   116  }
   117  
   118  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   119  // Exits with non-zero status after printing
   120  func (s *Server) Fatalf(f string, args ...interface{}) {
   121  	if s.api != nil && s.api.Logger != nil {
   122  		s.api.Logger(f, args...)
   123  		os.Exit(1)
   124  	} else {
   125  		log.Fatalf(f, args...)
   126  	}
   127  }
   128  
   129  // SetAPI configures the server with the specified API. Needs to be called before Serve
   130  func (s *Server) SetAPI(api *operations.TodoListAPI) {
   131  	if api == nil {
   132  		s.api = nil
   133  		s.handler = nil
   134  		return
   135  	}
   136  
   137  	s.api = api
   138  	s.handler = configureAPI(api)
   139  }
   140  
   141  func (s *Server) hasScheme(scheme string) bool {
   142  	schemes := s.EnabledListeners
   143  	if len(schemes) == 0 {
   144  		schemes = defaultSchemes
   145  	}
   146  
   147  	for _, v := range schemes {
   148  		if v == scheme {
   149  			return true
   150  		}
   151  	}
   152  	return false
   153  }
   154  
// Serve starts every enabled server and blocks until all of them have stopped.
//
// If the listeners are not open yet, Serve opens them through Listen. When
// no handler was set, it installs the API's default handler
// (s.api.Serve(nil)); having no API at all is then an error. Serve builds one
// *http.Server per enabled scheme (unix, http, https) and runs each on its own
// goroutine. A further goroutine runs handleShutdown, which waits for Shutdown
// (called explicitly, or by handleInterrupt on SIGINT/SIGTERM) and stops the
// servers.
//
// Serve returns nil once all servers and the shutdown handler are done. It
// returns an error in three cases:
//   - the listeners cannot be opened;
//   - no handler can be built;
//   - the TLS key pair or CA file cannot be loaded.
//
// Two other failures terminate the process through Fatalf instead of being
// returned: a serving error other than http.ErrServerClosed, and a missing
// TLS certificate.
//
// NOTE(review): the HTTPS-related error returns happen after the unix and
// http servers were already started on their goroutines; those servers are
// not stopped before Serve returns the error.
func (s *Server) Serve() (err error) {
	if !s.hasListeners {
		if err = s.Listen(); err != nil {
			return err
		}
	}

	// set default handler, if none is set
	if s.handler == nil {
		if s.api == nil {
			return errors.New("can't create the default handler, as no api is set")
		}

		s.SetHandler(s.api.Serve(nil))
	}

	// wg holds one slot per running server plus one for handleShutdown.
	wg := new(sync.WaitGroup)
	once := new(sync.Once)
	// Route SIGINT/SIGTERM into s.interrupt; handleInterrupt turns the first
	// one into a Shutdown call.
	signalNotify(s.interrupt)
	go handleInterrupt(once, s)

	// Every started server is collected here so handleShutdown can stop it.
	servers := []*http.Server{}

	if s.hasScheme(schemeUnix) {
		domainSocket := new(http.Server)
		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
		domainSocket.Handler = s.handler
		if int64(s.CleanupTimeout) > 0 {
			domainSocket.IdleTimeout = s.CleanupTimeout
		}

		configureServer(domainSocket, "unix", string(s.SocketPath))

		servers = append(servers, domainSocket)
		wg.Add(1)
		s.Logf("Serving todo list at unix://%s", s.SocketPath)
		go func(l net.Listener) {
			defer wg.Done()
			// http.ErrServerClosed is the normal result of Shutdown.
			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving todo list at unix://%s", s.SocketPath)
		}(s.domainSocketL)
	}

	if s.hasScheme(schemeHTTP) {
		httpServer := new(http.Server)
		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpServer.ReadTimeout = s.ReadTimeout
		httpServer.WriteTimeout = s.WriteTimeout
		// NOTE(review): KeepAlive only toggles HTTP keep-alives here; no TCP
		// keep-alive period is configured in this file.
		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
		if s.ListenLimit > 0 {
			// Replaces the stored listener, so HTTPListener now returns the limited one.
			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
		}

		if int64(s.CleanupTimeout) > 0 {
			httpServer.IdleTimeout = s.CleanupTimeout
		}

		httpServer.Handler = s.handler

		configureServer(httpServer, "http", s.httpServerL.Addr().String())

		servers = append(servers, httpServer)
		wg.Add(1)
		s.Logf("Serving todo list at http://%s", s.httpServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving todo list at http://%s", l.Addr())
		}(s.httpServerL)
	}

	if s.hasScheme(schemeHTTPS) {
		// The TLS* settings used below were defaulted from the HTTP ones by Listen.
		httpsServer := new(http.Server)
		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpsServer.ReadTimeout = s.TLSReadTimeout
		httpsServer.WriteTimeout = s.TLSWriteTimeout
		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
		if s.TLSListenLimit > 0 {
			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
		}
		if int64(s.CleanupTimeout) > 0 {
			httpsServer.IdleTimeout = s.CleanupTimeout
		}
		httpsServer.Handler = s.handler

		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
		httpsServer.TLSConfig = &tls.Config{
			// Prefer the server's cipher suite order over the client's.
			// NOTE(review): newer Go releases document this field as
			// deprecated/ignored — confirm for the toolchain in use.
			PreferServerCipherSuites: true,
			// Only use curves which have assembly implementations
			// https://github.com/golang/go/tree/master/src/crypto/elliptic
			CurvePreferences: []tls.CurveID{tls.CurveP256},
			// Application protocols offered in the handshake, in order of
			// preference: HTTP/2 first, then HTTP/1.1.
			NextProtos: []string{"h2", "http/1.1"},
			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
			MinVersion: tls.VersionTLS12,
			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
			CipherSuites: []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			},
		}

		// build standard config from server options
		// (a key pair is loaded only when both --tls-certificate and --tls-key are set)
		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
			if err != nil {
				return err
			}
		}

		if s.TLSCACertificate != "" {
			// include specified CA certificate; this turns on mutual TLS:
			// every client must present a certificate verified against it.
			caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
			if caCertErr != nil {
				return caCertErr
			}
			caCertPool := x509.NewCertPool()
			ok := caCertPool.AppendCertsFromPEM(caCert)
			if !ok {
				return fmt.Errorf("cannot parse CA certificate")
			}
			httpsServer.TLSConfig.ClientCAs = caCertPool
			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
		}

		// call custom TLS configurator
		configureTLS(httpsServer.TLSConfig)

		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
			// after standard and custom config are passed, this ends up with no certificate
			// (Fatalf exits, so only the first matching message is ever logged)
			if s.TLSCertificate == "" {
				if s.TLSCertificateKey == "" {
					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
				}
				s.Fatalf("the required flag `--tls-certificate` was not specified")
			}
			if s.TLSCertificateKey == "" {
				s.Fatalf("the required flag `--tls-key` was not specified")
			}
			// this happens with a wrong custom TLS configurator
			s.Fatalf("no certificate was configured for TLS")
		}

		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())

		servers = append(servers, httpsServer)
		wg.Add(1)
		s.Logf("Serving todo list at https://%s", s.httpsServerL.Addr())
		// The raw TCP listener is wrapped in TLS here, which is why plain
		// Serve (not ServeTLS) is used.
		go func(l net.Listener) {
			defer wg.Done()
			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving todo list at https://%s", l.Addr())
		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
	}

	// handleShutdown also holds a wg slot and releases it only after the
	// shutdown hooks ran, so wg.Wait covers them too.
	wg.Add(1)
	go s.handleShutdown(wg, &servers)

	wg.Wait()
	return nil
}
   330  
   331  // Listen creates the listeners for the server
   332  func (s *Server) Listen() error {
   333  	if s.hasListeners { // already done this
   334  		return nil
   335  	}
   336  
   337  	if s.hasScheme(schemeHTTPS) {
   338  		// Use http host if https host wasn't defined
   339  		if s.TLSHost == "" {
   340  			s.TLSHost = s.Host
   341  		}
   342  		// Use http listen limit if https listen limit wasn't defined
   343  		if s.TLSListenLimit == 0 {
   344  			s.TLSListenLimit = s.ListenLimit
   345  		}
   346  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   347  		if int64(s.TLSKeepAlive) == 0 {
   348  			s.TLSKeepAlive = s.KeepAlive
   349  		}
   350  		// Use http read timeout if https read timeout wasn't defined
   351  		if int64(s.TLSReadTimeout) == 0 {
   352  			s.TLSReadTimeout = s.ReadTimeout
   353  		}
   354  		// Use http write timeout if https write timeout wasn't defined
   355  		if int64(s.TLSWriteTimeout) == 0 {
   356  			s.TLSWriteTimeout = s.WriteTimeout
   357  		}
   358  	}
   359  
   360  	if s.hasScheme(schemeUnix) {
   361  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   362  		if err != nil {
   363  			return err
   364  		}
   365  		s.domainSocketL = domSockListener
   366  	}
   367  
   368  	if s.hasScheme(schemeHTTP) {
   369  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   370  		if err != nil {
   371  			return err
   372  		}
   373  
   374  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   375  		if err != nil {
   376  			return err
   377  		}
   378  		s.Host = h
   379  		s.Port = p
   380  		s.httpServerL = listener
   381  	}
   382  
   383  	if s.hasScheme(schemeHTTPS) {
   384  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   385  		if err != nil {
   386  			return err
   387  		}
   388  
   389  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   390  		if err != nil {
   391  			return err
   392  		}
   393  		s.TLSHost = sh
   394  		s.TLSPort = sp
   395  		s.httpsServerL = tlsListener
   396  	}
   397  
   398  	s.hasListeners = true
   399  	return nil
   400  }
   401  
   402  // Shutdown server and clean up resources
   403  func (s *Server) Shutdown() error {
   404  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   405  		close(s.shutdown)
   406  	}
   407  	return nil
   408  }
   409  
   410  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   411  	// wg.Done must occur last, after s.api.ServerShutdown()
   412  	// (to preserve old behaviour)
   413  	defer wg.Done()
   414  
   415  	<-s.shutdown
   416  
   417  	servers := *serversPtr
   418  
   419  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   420  	defer cancel()
   421  
   422  	// first execute the pre-shutdown hook
   423  	s.api.PreServerShutdown()
   424  
   425  	shutdownChan := make(chan bool)
   426  	for i := range servers {
   427  		server := servers[i]
   428  		go func() {
   429  			var success bool
   430  			defer func() {
   431  				shutdownChan <- success
   432  			}()
   433  			if err := server.Shutdown(ctx); err != nil {
   434  				// Error from closing listeners, or context timeout:
   435  				s.Logf("HTTP server Shutdown: %v", err)
   436  			} else {
   437  				success = true
   438  			}
   439  		}()
   440  	}
   441  
   442  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   443  	success := true
   444  	for range servers {
   445  		success = success && <-shutdownChan
   446  	}
   447  	if success {
   448  		s.api.ServerShutdown()
   449  	}
   450  }
   451  
   452  // GetHandler returns a handler useful for testing
   453  func (s *Server) GetHandler() http.Handler {
   454  	return s.handler
   455  }
   456  
   457  // SetHandler allows for setting a http handler on this server
   458  func (s *Server) SetHandler(handler http.Handler) {
   459  	s.handler = handler
   460  }
   461  
   462  // UnixListener returns the domain socket listener
   463  func (s *Server) UnixListener() (net.Listener, error) {
   464  	if !s.hasListeners {
   465  		if err := s.Listen(); err != nil {
   466  			return nil, err
   467  		}
   468  	}
   469  	return s.domainSocketL, nil
   470  }
   471  
   472  // HTTPListener returns the http listener
   473  func (s *Server) HTTPListener() (net.Listener, error) {
   474  	if !s.hasListeners {
   475  		if err := s.Listen(); err != nil {
   476  			return nil, err
   477  		}
   478  	}
   479  	return s.httpServerL, nil
   480  }
   481  
   482  // TLSListener returns the https listener
   483  func (s *Server) TLSListener() (net.Listener, error) {
   484  	if !s.hasListeners {
   485  		if err := s.Listen(); err != nil {
   486  			return nil, err
   487  		}
   488  	}
   489  	return s.httpsServerL, nil
   490  }
   491  
   492  func handleInterrupt(once *sync.Once, s *Server) {
   493  	once.Do(func() {
   494  		for range s.interrupt {
   495  			if s.interrupted {
   496  				s.Logf("Server already shutting down")
   497  				continue
   498  			}
   499  			s.interrupted = true
   500  			s.Logf("Shutting down... ")
   501  			if err := s.Shutdown(); err != nil {
   502  				s.Logf("HTTP server Shutdown: %v", err)
   503  			}
   504  		}
   505  	})
   506  }
   507  
   508  func signalNotify(interrupt chan<- os.Signal) {
   509  	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
   510  }