github.com/circl-dev/go-swagger@v0.31.0/examples/task-tracker/restapi/server.go (about)

     1  // Code generated by go-swagger; DO NOT EDIT.
     2  
     3  package restapi
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"os/signal"
    17  	"strconv"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  
    23  	"github.com/go-openapi/swag"
    24  	flags "github.com/jessevdk/go-flags"
    25  	"github.com/circl-dev/runtime/flagext"
    26  	"golang.org/x/net/netutil"
    27  
    28  	"github.com/circl-dev/go-swagger/examples/task-tracker/restapi/operations"
    29  )
    30  
// Scheme names a listener can be enabled for; they are the values matched by
// hasScheme against Server.EnabledListeners.
const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
	schemeUnix  = "unix"
)

// defaultSchemes is the list of schemes hasScheme falls back to when
// Server.EnabledListeners is empty.
var defaultSchemes []string

// init fills defaultSchemes with the http and https schemes.
func init() {
	defaultSchemes = []string{
		schemeHTTP,
		schemeHTTPS,
	}
}
    45  
    46  // NewServer creates a new api task tracker server but does not configure it
    47  func NewServer(api *operations.TaskTrackerAPI) *Server {
    48  	s := new(Server)
    49  
    50  	s.shutdown = make(chan struct{})
    51  	s.api = api
    52  	s.interrupt = make(chan os.Signal, 1)
    53  	return s
    54  }
    55  
    56  // ConfigureAPI configures the API and handlers.
    57  func (s *Server) ConfigureAPI() {
    58  	if s.api != nil {
    59  		s.handler = configureAPI(s.api)
    60  	}
    61  }
    62  
    63  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    64  func (s *Server) ConfigureFlags() {
    65  	if s.api != nil {
    66  		configureFlags(s.api)
    67  	}
    68  }
    69  
// Server for the task tracker API.
//
// The exported fields are the server's options; their struct tags give the
// command-line flag name, help text, default and environment variable
// (presumably consumed by the go-flags parser — confirm against the caller).
// The unexported fields hold runtime state managed by Listen, Serve and
// Shutdown.
type Server struct {
	// Options shared by all listeners.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket listener options; domainSocketL is set by Listen.
	SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/task-tracker.sock"`
	domainSocketL net.Listener

	// Plain HTTP listener options; httpServerL is set by Listen and, when
	// ListenLimit > 0, wrapped with a limit listener by Serve.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	httpServerL  net.Listener

	// HTTPS listener options. Listen copies the plain HTTP value into
	// TLSHost, TLSListenLimit, TLSKeepAlive, TLSReadTimeout and
	// TLSWriteTimeout when they are left at their zero value.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	httpsServerL      net.Listener

	// Runtime state.
	api          *operations.TaskTrackerAPI // API served; may be nil (see SetAPI)
	handler      http.Handler               // handler installed on every http.Server
	hasListeners bool                       // set by Listen once all listeners exist
	shutdown     chan struct{}              // closed by Shutdown to trigger handleShutdown
	shuttingDown int32                      // compare-and-swapped by Shutdown so the channel is closed once
	interrupted  bool                       // set by handleInterrupt after the first signal
	interrupt    chan os.Signal             // receives the signals registered by signalNotify
}
   107  
   108  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   109  func (s *Server) Logf(f string, args ...interface{}) {
   110  	if s.api != nil && s.api.Logger != nil {
   111  		s.api.Logger(f, args...)
   112  	} else {
   113  		log.Printf(f, args...)
   114  	}
   115  }
   116  
   117  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   118  // Exits with non-zero status after printing
   119  func (s *Server) Fatalf(f string, args ...interface{}) {
   120  	if s.api != nil && s.api.Logger != nil {
   121  		s.api.Logger(f, args...)
   122  		os.Exit(1)
   123  	} else {
   124  		log.Fatalf(f, args...)
   125  	}
   126  }
   127  
   128  // SetAPI configures the server with the specified API. Needs to be called before Serve
   129  func (s *Server) SetAPI(api *operations.TaskTrackerAPI) {
   130  	if api == nil {
   131  		s.api = nil
   132  		s.handler = nil
   133  		return
   134  	}
   135  
   136  	s.api = api
   137  	s.handler = configureAPI(api)
   138  }
   139  
   140  func (s *Server) hasScheme(scheme string) bool {
   141  	schemes := s.EnabledListeners
   142  	if len(schemes) == 0 {
   143  		schemes = defaultSchemes
   144  	}
   145  
   146  	for _, v := range schemes {
   147  		if v == scheme {
   148  			return true
   149  		}
   150  	}
   151  	return false
   152  }
   153  
   154  // Serve the api
   155  func (s *Server) Serve() (err error) {
   156  	if !s.hasListeners {
   157  		if err = s.Listen(); err != nil {
   158  			return err
   159  		}
   160  	}
   161  
   162  	// set default handler, if none is set
   163  	if s.handler == nil {
   164  		if s.api == nil {
   165  			return errors.New("can't create the default handler, as no api is set")
   166  		}
   167  
   168  		s.SetHandler(s.api.Serve(nil))
   169  	}
   170  
   171  	wg := new(sync.WaitGroup)
   172  	once := new(sync.Once)
   173  	signalNotify(s.interrupt)
   174  	go handleInterrupt(once, s)
   175  
   176  	servers := []*http.Server{}
   177  
   178  	if s.hasScheme(schemeUnix) {
   179  		domainSocket := new(http.Server)
   180  		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
   181  		domainSocket.Handler = s.handler
   182  		if int64(s.CleanupTimeout) > 0 {
   183  			domainSocket.IdleTimeout = s.CleanupTimeout
   184  		}
   185  
   186  		configureServer(domainSocket, "unix", string(s.SocketPath))
   187  
   188  		servers = append(servers, domainSocket)
   189  		wg.Add(1)
   190  		s.Logf("Serving task tracker at unix://%s", s.SocketPath)
   191  		go func(l net.Listener) {
   192  			defer wg.Done()
   193  			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
   194  				s.Fatalf("%v", err)
   195  			}
   196  			s.Logf("Stopped serving task tracker at unix://%s", s.SocketPath)
   197  		}(s.domainSocketL)
   198  	}
   199  
   200  	if s.hasScheme(schemeHTTP) {
   201  		httpServer := new(http.Server)
   202  		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   203  		httpServer.ReadTimeout = s.ReadTimeout
   204  		httpServer.WriteTimeout = s.WriteTimeout
   205  		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
   206  		if s.ListenLimit > 0 {
   207  			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
   208  		}
   209  
   210  		if int64(s.CleanupTimeout) > 0 {
   211  			httpServer.IdleTimeout = s.CleanupTimeout
   212  		}
   213  
   214  		httpServer.Handler = s.handler
   215  
   216  		configureServer(httpServer, "http", s.httpServerL.Addr().String())
   217  
   218  		servers = append(servers, httpServer)
   219  		wg.Add(1)
   220  		s.Logf("Serving task tracker at http://%s", s.httpServerL.Addr())
   221  		go func(l net.Listener) {
   222  			defer wg.Done()
   223  			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
   224  				s.Fatalf("%v", err)
   225  			}
   226  			s.Logf("Stopped serving task tracker at http://%s", l.Addr())
   227  		}(s.httpServerL)
   228  	}
   229  
   230  	if s.hasScheme(schemeHTTPS) {
   231  		httpsServer := new(http.Server)
   232  		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   233  		httpsServer.ReadTimeout = s.TLSReadTimeout
   234  		httpsServer.WriteTimeout = s.TLSWriteTimeout
   235  		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
   236  		if s.TLSListenLimit > 0 {
   237  			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
   238  		}
   239  		if int64(s.CleanupTimeout) > 0 {
   240  			httpsServer.IdleTimeout = s.CleanupTimeout
   241  		}
   242  		httpsServer.Handler = s.handler
   243  
   244  		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
   245  		httpsServer.TLSConfig = &tls.Config{
   246  			// Causes servers to use Go's default ciphersuite preferences,
   247  			// which are tuned to avoid attacks. Does nothing on clients.
   248  			PreferServerCipherSuites: true,
   249  			// Only use curves which have assembly implementations
   250  			// https://github.com/golang/go/tree/master/src/crypto/elliptic
   251  			CurvePreferences: []tls.CurveID{tls.CurveP256},
   252  			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
   253  			NextProtos: []string{"h2", "http/1.1"},
   254  			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
   255  			MinVersion: tls.VersionTLS12,
   256  			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
   257  			CipherSuites: []uint16{
   258  				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   259  				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
   260  				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   261  				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   262  				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
   263  				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
   264  			},
   265  		}
   266  
   267  		// build standard config from server options
   268  		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
   269  			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
   270  			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
   271  			if err != nil {
   272  				return err
   273  			}
   274  		}
   275  
   276  		if s.TLSCACertificate != "" {
   277  			// include specified CA certificate
   278  			caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
   279  			if caCertErr != nil {
   280  				return caCertErr
   281  			}
   282  			caCertPool := x509.NewCertPool()
   283  			ok := caCertPool.AppendCertsFromPEM(caCert)
   284  			if !ok {
   285  				return fmt.Errorf("cannot parse CA certificate")
   286  			}
   287  			httpsServer.TLSConfig.ClientCAs = caCertPool
   288  			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
   289  		}
   290  
   291  		// call custom TLS configurator
   292  		configureTLS(httpsServer.TLSConfig)
   293  
   294  		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
   295  			// after standard and custom config are passed, this ends up with no certificate
   296  			if s.TLSCertificate == "" {
   297  				if s.TLSCertificateKey == "" {
   298  					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
   299  				}
   300  				s.Fatalf("the required flag `--tls-certificate` was not specified")
   301  			}
   302  			if s.TLSCertificateKey == "" {
   303  				s.Fatalf("the required flag `--tls-key` was not specified")
   304  			}
   305  			// this happens with a wrong custom TLS configurator
   306  			s.Fatalf("no certificate was configured for TLS")
   307  		}
   308  
   309  		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
   310  
   311  		servers = append(servers, httpsServer)
   312  		wg.Add(1)
   313  		s.Logf("Serving task tracker at https://%s", s.httpsServerL.Addr())
   314  		go func(l net.Listener) {
   315  			defer wg.Done()
   316  			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
   317  				s.Fatalf("%v", err)
   318  			}
   319  			s.Logf("Stopped serving task tracker at https://%s", l.Addr())
   320  		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
   321  	}
   322  
   323  	wg.Add(1)
   324  	go s.handleShutdown(wg, &servers)
   325  
   326  	wg.Wait()
   327  	return nil
   328  }
   329  
   330  // Listen creates the listeners for the server
   331  func (s *Server) Listen() error {
   332  	if s.hasListeners { // already done this
   333  		return nil
   334  	}
   335  
   336  	if s.hasScheme(schemeHTTPS) {
   337  		// Use http host if https host wasn't defined
   338  		if s.TLSHost == "" {
   339  			s.TLSHost = s.Host
   340  		}
   341  		// Use http listen limit if https listen limit wasn't defined
   342  		if s.TLSListenLimit == 0 {
   343  			s.TLSListenLimit = s.ListenLimit
   344  		}
   345  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   346  		if int64(s.TLSKeepAlive) == 0 {
   347  			s.TLSKeepAlive = s.KeepAlive
   348  		}
   349  		// Use http read timeout if https read timeout wasn't defined
   350  		if int64(s.TLSReadTimeout) == 0 {
   351  			s.TLSReadTimeout = s.ReadTimeout
   352  		}
   353  		// Use http write timeout if https write timeout wasn't defined
   354  		if int64(s.TLSWriteTimeout) == 0 {
   355  			s.TLSWriteTimeout = s.WriteTimeout
   356  		}
   357  	}
   358  
   359  	if s.hasScheme(schemeUnix) {
   360  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   361  		if err != nil {
   362  			return err
   363  		}
   364  		s.domainSocketL = domSockListener
   365  	}
   366  
   367  	if s.hasScheme(schemeHTTP) {
   368  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   369  		if err != nil {
   370  			return err
   371  		}
   372  
   373  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   374  		if err != nil {
   375  			return err
   376  		}
   377  		s.Host = h
   378  		s.Port = p
   379  		s.httpServerL = listener
   380  	}
   381  
   382  	if s.hasScheme(schemeHTTPS) {
   383  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   384  		if err != nil {
   385  			return err
   386  		}
   387  
   388  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   389  		if err != nil {
   390  			return err
   391  		}
   392  		s.TLSHost = sh
   393  		s.TLSPort = sp
   394  		s.httpsServerL = tlsListener
   395  	}
   396  
   397  	s.hasListeners = true
   398  	return nil
   399  }
   400  
   401  // Shutdown server and clean up resources
   402  func (s *Server) Shutdown() error {
   403  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   404  		close(s.shutdown)
   405  	}
   406  	return nil
   407  }
   408  
   409  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   410  	// wg.Done must occur last, after s.api.ServerShutdown()
   411  	// (to preserve old behaviour)
   412  	defer wg.Done()
   413  
   414  	<-s.shutdown
   415  
   416  	servers := *serversPtr
   417  
   418  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   419  	defer cancel()
   420  
   421  	// first execute the pre-shutdown hook
   422  	s.api.PreServerShutdown()
   423  
   424  	shutdownChan := make(chan bool)
   425  	for i := range servers {
   426  		server := servers[i]
   427  		go func() {
   428  			var success bool
   429  			defer func() {
   430  				shutdownChan <- success
   431  			}()
   432  			if err := server.Shutdown(ctx); err != nil {
   433  				// Error from closing listeners, or context timeout:
   434  				s.Logf("HTTP server Shutdown: %v", err)
   435  			} else {
   436  				success = true
   437  			}
   438  		}()
   439  	}
   440  
   441  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   442  	success := true
   443  	for range servers {
   444  		success = success && <-shutdownChan
   445  	}
   446  	if success {
   447  		s.api.ServerShutdown()
   448  	}
   449  }
   450  
// GetHandler returns the handler currently installed on the server, which is
// useful for testing. The result is nil when no handler has been configured.
func (s *Server) GetHandler() http.Handler {
	return s.handler
}
   455  
// SetHandler installs handler as the http handler of this server, replacing
// any handler set previously.
func (s *Server) SetHandler(handler http.Handler) {
	s.handler = handler
}
   460  
   461  // UnixListener returns the domain socket listener
   462  func (s *Server) UnixListener() (net.Listener, error) {
   463  	if !s.hasListeners {
   464  		if err := s.Listen(); err != nil {
   465  			return nil, err
   466  		}
   467  	}
   468  	return s.domainSocketL, nil
   469  }
   470  
   471  // HTTPListener returns the http listener
   472  func (s *Server) HTTPListener() (net.Listener, error) {
   473  	if !s.hasListeners {
   474  		if err := s.Listen(); err != nil {
   475  			return nil, err
   476  		}
   477  	}
   478  	return s.httpServerL, nil
   479  }
   480  
   481  // TLSListener returns the https listener
   482  func (s *Server) TLSListener() (net.Listener, error) {
   483  	if !s.hasListeners {
   484  		if err := s.Listen(); err != nil {
   485  			return nil, err
   486  		}
   487  	}
   488  	return s.httpsServerL, nil
   489  }
   490  
   491  func handleInterrupt(once *sync.Once, s *Server) {
   492  	once.Do(func() {
   493  		for range s.interrupt {
   494  			if s.interrupted {
   495  				s.Logf("Server already shutting down")
   496  				continue
   497  			}
   498  			s.interrupted = true
   499  			s.Logf("Shutting down... ")
   500  			if err := s.Shutdown(); err != nil {
   501  				s.Logf("HTTP server Shutdown: %v", err)
   502  			}
   503  		}
   504  	})
   505  }
   506  
// signalNotify registers interrupt to receive SIGINT and SIGTERM.
func signalNotify(interrupt chan<- os.Signal) {
	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
}