github.com/dirkm/go-swagger@v0.19.0/examples/task-tracker/restapi/server.go (about)

     1  // Code generated by go-swagger; DO NOT EDIT.
     2  
     3  package restapi
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"os/signal"
    17  	"strconv"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  
    23  	"github.com/go-openapi/runtime/flagext"
    24  	"github.com/go-openapi/swag"
    25  	flags "github.com/jessevdk/go-flags"
    26  	"golang.org/x/net/netutil"
    27  
    28  	"github.com/go-swagger/go-swagger/examples/task-tracker/restapi/operations"
    29  )
    30  
    31  const (
    32  	schemeHTTP  = "http"
    33  	schemeHTTPS = "https"
    34  	schemeUnix  = "unix"
    35  )
    36  
    37  var defaultSchemes []string
    38  
    39  func init() {
    40  	defaultSchemes = []string{
    41  		schemeHTTP,
    42  		schemeHTTPS,
    43  	}
    44  }
    45  
    46  // NewServer creates a new api task tracker server but does not configure it
    47  func NewServer(api *operations.TaskTrackerAPI) *Server {
    48  	s := new(Server)
    49  
    50  	s.shutdown = make(chan struct{})
    51  	s.api = api
    52  	s.interrupt = make(chan os.Signal, 1)
    53  	return s
    54  }
    55  
    56  // ConfigureAPI configures the API and handlers.
    57  func (s *Server) ConfigureAPI() {
    58  	if s.api != nil {
    59  		s.handler = configureAPI(s.api)
    60  	}
    61  }
    62  
    63  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    64  func (s *Server) ConfigureFlags() {
    65  	if s.api != nil {
    66  		configureFlags(s.api)
    67  	}
    68  }
    69  
// Server for the task tracker API
//
// The exported fields are the server's options. Their struct tags (long,
// description, default, env) follow the github.com/jessevdk/go-flags
// conventions, so they are meant to be filled from command-line flags and
// environment variables. The unexported fields hold runtime state set up by
// NewServer, SetAPI/ConfigureAPI/SetHandler, Listen and Serve.
type Server struct {
	// Options shared by every listener. Serve applies MaxHeaderSize to each
	// http.Server it builds and CleanupTimeout as their IdleTimeout (when
	// > 0). GracefulTimeout bounds the graceful shutdown in handleShutdown.
	// EnabledListeners selects schemes; when empty, defaultSchemes is used.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket listener (scheme "unix"). Listen opens
	// domainSocketL at SocketPath.
	SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/task-tracker.sock"`
	domainSocketL net.Listener

	// Plain HTTP listener (scheme "http"). Listen binds Host:Port and then
	// writes the address actually bound back into Host and Port, so a Port
	// of 0 is replaced by the port the OS picked. Serve wraps httpServerL in
	// a netutil.LimitListener when ListenLimit > 0.
	// NOTE(review): KeepAlive is only used in Serve as an on/off switch for
	// HTTP keep-alives (KeepAlive > 0). Despite the flag description, the
	// duration itself is not applied anywhere in this file.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	httpServerL  net.Listener

	// HTTPS listener (scheme "https").
	// - Listen copies TLSHost, TLSListenLimit, TLSKeepAlive, TLSReadTimeout
	//   and TLSWriteTimeout from their plain-HTTP counterparts when they are
	//   left empty/zero.
	// - Listen writes the bound address back into TLSHost and TLSPort.
	// - Serve loads TLSCertificate/TLSCertificateKey as the server key pair.
	// - TLSCACertificate, when set, makes client certificates mandatory
	//   (tls.RequireAndVerifyClientCert).
	// - Like KeepAlive, TLSKeepAlive only toggles HTTP keep-alives in Serve.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	httpsServerL      net.Listener

	// Runtime state:
	//   api          - the API being served (NewServer, SetAPI).
	//   handler      - root http.Handler (ConfigureAPI, SetAPI, SetHandler);
	//                  Serve falls back to api.Serve(nil) when it is nil.
	//   hasListeners - set once Listen has opened the enabled listeners.
	//   shutdown     - closed by Shutdown; handleShutdown waits on it.
	//   shuttingDown - 0/1 flag swapped atomically by Shutdown so that
	//                  shutdown is closed only once.
	//   interrupted  - set by handleInterrupt on the first signal.
	//   interrupt    - channel (capacity 1, see NewServer) registered for
	//                  SIGINT and SIGTERM by signalNotify.
	api          *operations.TaskTrackerAPI
	handler      http.Handler
	hasListeners bool
	shutdown     chan struct{}
	shuttingDown int32
	interrupted  bool
	interrupt    chan os.Signal
}
   107  
   108  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   109  func (s *Server) Logf(f string, args ...interface{}) {
   110  	if s.api != nil && s.api.Logger != nil {
   111  		s.api.Logger(f, args...)
   112  	} else {
   113  		log.Printf(f, args...)
   114  	}
   115  }
   116  
   117  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   118  // Exits with non-zero status after printing
   119  func (s *Server) Fatalf(f string, args ...interface{}) {
   120  	if s.api != nil && s.api.Logger != nil {
   121  		s.api.Logger(f, args...)
   122  		os.Exit(1)
   123  	} else {
   124  		log.Fatalf(f, args...)
   125  	}
   126  }
   127  
   128  // SetAPI configures the server with the specified API. Needs to be called before Serve
   129  func (s *Server) SetAPI(api *operations.TaskTrackerAPI) {
   130  	if api == nil {
   131  		s.api = nil
   132  		s.handler = nil
   133  		return
   134  	}
   135  
   136  	s.api = api
   137  	s.api.Logger = log.Printf
   138  	s.handler = configureAPI(api)
   139  }
   140  
   141  func (s *Server) hasScheme(scheme string) bool {
   142  	schemes := s.EnabledListeners
   143  	if len(schemes) == 0 {
   144  		schemes = defaultSchemes
   145  	}
   146  
   147  	for _, v := range schemes {
   148  		if v == scheme {
   149  			return true
   150  		}
   151  	}
   152  	return false
   153  }
   154  
   155  // Serve the api
   156  func (s *Server) Serve() (err error) {
   157  	if !s.hasListeners {
   158  		if err = s.Listen(); err != nil {
   159  			return err
   160  		}
   161  	}
   162  
   163  	// set default handler, if none is set
   164  	if s.handler == nil {
   165  		if s.api == nil {
   166  			return errors.New("can't create the default handler, as no api is set")
   167  		}
   168  
   169  		s.SetHandler(s.api.Serve(nil))
   170  	}
   171  
   172  	wg := new(sync.WaitGroup)
   173  	once := new(sync.Once)
   174  	signalNotify(s.interrupt)
   175  	go handleInterrupt(once, s)
   176  
   177  	servers := []*http.Server{}
   178  	wg.Add(1)
   179  	go s.handleShutdown(wg, &servers)
   180  
   181  	if s.hasScheme(schemeUnix) {
   182  		domainSocket := new(http.Server)
   183  		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
   184  		domainSocket.Handler = s.handler
   185  		if int64(s.CleanupTimeout) > 0 {
   186  			domainSocket.IdleTimeout = s.CleanupTimeout
   187  		}
   188  
   189  		configureServer(domainSocket, "unix", string(s.SocketPath))
   190  
   191  		servers = append(servers, domainSocket)
   192  		wg.Add(1)
   193  		s.Logf("Serving task tracker at unix://%s", s.SocketPath)
   194  		go func(l net.Listener) {
   195  			defer wg.Done()
   196  			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
   197  				s.Fatalf("%v", err)
   198  			}
   199  			s.Logf("Stopped serving task tracker at unix://%s", s.SocketPath)
   200  		}(s.domainSocketL)
   201  	}
   202  
   203  	if s.hasScheme(schemeHTTP) {
   204  		httpServer := new(http.Server)
   205  		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   206  		httpServer.ReadTimeout = s.ReadTimeout
   207  		httpServer.WriteTimeout = s.WriteTimeout
   208  		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
   209  		if s.ListenLimit > 0 {
   210  			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
   211  		}
   212  
   213  		if int64(s.CleanupTimeout) > 0 {
   214  			httpServer.IdleTimeout = s.CleanupTimeout
   215  		}
   216  
   217  		httpServer.Handler = s.handler
   218  
   219  		configureServer(httpServer, "http", s.httpServerL.Addr().String())
   220  
   221  		servers = append(servers, httpServer)
   222  		wg.Add(1)
   223  		s.Logf("Serving task tracker at http://%s", s.httpServerL.Addr())
   224  		go func(l net.Listener) {
   225  			defer wg.Done()
   226  			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
   227  				s.Fatalf("%v", err)
   228  			}
   229  			s.Logf("Stopped serving task tracker at http://%s", l.Addr())
   230  		}(s.httpServerL)
   231  	}
   232  
   233  	if s.hasScheme(schemeHTTPS) {
   234  		httpsServer := new(http.Server)
   235  		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   236  		httpsServer.ReadTimeout = s.TLSReadTimeout
   237  		httpsServer.WriteTimeout = s.TLSWriteTimeout
   238  		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
   239  		if s.TLSListenLimit > 0 {
   240  			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
   241  		}
   242  		if int64(s.CleanupTimeout) > 0 {
   243  			httpsServer.IdleTimeout = s.CleanupTimeout
   244  		}
   245  		httpsServer.Handler = s.handler
   246  
   247  		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
   248  		httpsServer.TLSConfig = &tls.Config{
   249  			// Causes servers to use Go's default ciphersuite preferences,
   250  			// which are tuned to avoid attacks. Does nothing on clients.
   251  			PreferServerCipherSuites: true,
   252  			// Only use curves which have assembly implementations
   253  			// https://github.com/golang/go/tree/master/src/crypto/elliptic
   254  			CurvePreferences: []tls.CurveID{tls.CurveP256},
   255  			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
   256  			NextProtos: []string{"http/1.1", "h2"},
   257  			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
   258  			MinVersion: tls.VersionTLS12,
   259  			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
   260  			CipherSuites: []uint16{
   261  				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   262  				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
   263  				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   264  				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   265  				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
   266  				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
   267  			},
   268  		}
   269  
   270  		// build standard config from server options
   271  		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
   272  			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
   273  			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
   274  			if err != nil {
   275  				return err
   276  			}
   277  		}
   278  
   279  		if s.TLSCACertificate != "" {
   280  			// include specified CA certificate
   281  			caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
   282  			if caCertErr != nil {
   283  				return caCertErr
   284  			}
   285  			caCertPool := x509.NewCertPool()
   286  			ok := caCertPool.AppendCertsFromPEM(caCert)
   287  			if !ok {
   288  				return fmt.Errorf("cannot parse CA certificate")
   289  			}
   290  			httpsServer.TLSConfig.ClientCAs = caCertPool
   291  			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
   292  		}
   293  
   294  		// call custom TLS configurator
   295  		configureTLS(httpsServer.TLSConfig)
   296  
   297  		if len(httpsServer.TLSConfig.Certificates) == 0 {
   298  			// after standard and custom config are passed, this ends up with no certificate
   299  			if s.TLSCertificate == "" {
   300  				if s.TLSCertificateKey == "" {
   301  					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
   302  				}
   303  				s.Fatalf("the required flag `--tls-certificate` was not specified")
   304  			}
   305  			if s.TLSCertificateKey == "" {
   306  				s.Fatalf("the required flag `--tls-key` was not specified")
   307  			}
   308  			// this happens with a wrong custom TLS configurator
   309  			s.Fatalf("no certificate was configured for TLS")
   310  		}
   311  
   312  		// must have at least one certificate or panics
   313  		httpsServer.TLSConfig.BuildNameToCertificate()
   314  
   315  		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
   316  
   317  		servers = append(servers, httpsServer)
   318  		wg.Add(1)
   319  		s.Logf("Serving task tracker at https://%s", s.httpsServerL.Addr())
   320  		go func(l net.Listener) {
   321  			defer wg.Done()
   322  			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
   323  				s.Fatalf("%v", err)
   324  			}
   325  			s.Logf("Stopped serving task tracker at https://%s", l.Addr())
   326  		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
   327  	}
   328  
   329  	wg.Wait()
   330  	return nil
   331  }
   332  
   333  // Listen creates the listeners for the server
   334  func (s *Server) Listen() error {
   335  	if s.hasListeners { // already done this
   336  		return nil
   337  	}
   338  
   339  	if s.hasScheme(schemeHTTPS) {
   340  		// Use http host if https host wasn't defined
   341  		if s.TLSHost == "" {
   342  			s.TLSHost = s.Host
   343  		}
   344  		// Use http listen limit if https listen limit wasn't defined
   345  		if s.TLSListenLimit == 0 {
   346  			s.TLSListenLimit = s.ListenLimit
   347  		}
   348  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   349  		if int64(s.TLSKeepAlive) == 0 {
   350  			s.TLSKeepAlive = s.KeepAlive
   351  		}
   352  		// Use http read timeout if https read timeout wasn't defined
   353  		if int64(s.TLSReadTimeout) == 0 {
   354  			s.TLSReadTimeout = s.ReadTimeout
   355  		}
   356  		// Use http write timeout if https write timeout wasn't defined
   357  		if int64(s.TLSWriteTimeout) == 0 {
   358  			s.TLSWriteTimeout = s.WriteTimeout
   359  		}
   360  	}
   361  
   362  	if s.hasScheme(schemeUnix) {
   363  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   364  		if err != nil {
   365  			return err
   366  		}
   367  		s.domainSocketL = domSockListener
   368  	}
   369  
   370  	if s.hasScheme(schemeHTTP) {
   371  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   372  		if err != nil {
   373  			return err
   374  		}
   375  
   376  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   377  		if err != nil {
   378  			return err
   379  		}
   380  		s.Host = h
   381  		s.Port = p
   382  		s.httpServerL = listener
   383  	}
   384  
   385  	if s.hasScheme(schemeHTTPS) {
   386  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   387  		if err != nil {
   388  			return err
   389  		}
   390  
   391  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   392  		if err != nil {
   393  			return err
   394  		}
   395  		s.TLSHost = sh
   396  		s.TLSPort = sp
   397  		s.httpsServerL = tlsListener
   398  	}
   399  
   400  	s.hasListeners = true
   401  	return nil
   402  }
   403  
   404  // Shutdown server and clean up resources
   405  func (s *Server) Shutdown() error {
   406  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   407  		close(s.shutdown)
   408  	}
   409  	return nil
   410  }
   411  
   412  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   413  	// wg.Done must occur last, after s.api.ServerShutdown()
   414  	// (to preserve old behaviour)
   415  	defer wg.Done()
   416  
   417  	<-s.shutdown
   418  
   419  	servers := *serversPtr
   420  
   421  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   422  	defer cancel()
   423  
   424  	shutdownChan := make(chan bool)
   425  	for i := range servers {
   426  		server := servers[i]
   427  		go func() {
   428  			var success bool
   429  			defer func() {
   430  				shutdownChan <- success
   431  			}()
   432  			if err := server.Shutdown(ctx); err != nil {
   433  				// Error from closing listeners, or context timeout:
   434  				s.Logf("HTTP server Shutdown: %v", err)
   435  			} else {
   436  				success = true
   437  			}
   438  		}()
   439  	}
   440  
   441  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   442  	success := true
   443  	for range servers {
   444  		success = success && <-shutdownChan
   445  	}
   446  	if success {
   447  		s.api.ServerShutdown()
   448  	}
   449  }
   450  
   451  // GetHandler returns a handler useful for testing
   452  func (s *Server) GetHandler() http.Handler {
   453  	return s.handler
   454  }
   455  
   456  // SetHandler allows for setting a http handler on this server
   457  func (s *Server) SetHandler(handler http.Handler) {
   458  	s.handler = handler
   459  }
   460  
   461  // UnixListener returns the domain socket listener
   462  func (s *Server) UnixListener() (net.Listener, error) {
   463  	if !s.hasListeners {
   464  		if err := s.Listen(); err != nil {
   465  			return nil, err
   466  		}
   467  	}
   468  	return s.domainSocketL, nil
   469  }
   470  
   471  // HTTPListener returns the http listener
   472  func (s *Server) HTTPListener() (net.Listener, error) {
   473  	if !s.hasListeners {
   474  		if err := s.Listen(); err != nil {
   475  			return nil, err
   476  		}
   477  	}
   478  	return s.httpServerL, nil
   479  }
   480  
   481  // TLSListener returns the https listener
   482  func (s *Server) TLSListener() (net.Listener, error) {
   483  	if !s.hasListeners {
   484  		if err := s.Listen(); err != nil {
   485  			return nil, err
   486  		}
   487  	}
   488  	return s.httpsServerL, nil
   489  }
   490  
   491  func handleInterrupt(once *sync.Once, s *Server) {
   492  	once.Do(func() {
   493  		for _ = range s.interrupt {
   494  			if s.interrupted {
   495  				s.Logf("Server already shutting down")
   496  				continue
   497  			}
   498  			s.interrupted = true
   499  			s.Logf("Shutting down... ")
   500  			if err := s.Shutdown(); err != nil {
   501  				s.Logf("HTTP server Shutdown: %v", err)
   502  			}
   503  		}
   504  	})
   505  }
   506  
   507  func signalNotify(interrupt chan<- os.Signal) {
   508  	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
   509  }