github.com/kaisawind/go-swagger@v0.19.0/examples/authentication/restapi/server.go (about)

     1  // Code generated by go-swagger; DO NOT EDIT.
     2  
     3  package restapi
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"os/signal"
    17  	"strconv"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  
    23  	"github.com/go-openapi/runtime/flagext"
    24  	"github.com/go-openapi/swag"
    25  	flags "github.com/jessevdk/go-flags"
    26  	"golang.org/x/net/netutil"
    27  
    28  	"github.com/go-swagger/go-swagger/examples/authentication/restapi/operations"
    29  )
    30  
    31  const (
    32  	schemeHTTP  = "http"
    33  	schemeHTTPS = "https"
    34  	schemeUnix  = "unix"
    35  )
    36  
    37  var defaultSchemes []string
    38  
    39  func init() {
    40  	defaultSchemes = []string{
    41  		schemeHTTP,
    42  	}
    43  }
    44  
    45  // NewServer creates a new api auth sample server but does not configure it
    46  func NewServer(api *operations.AuthSampleAPI) *Server {
    47  	s := new(Server)
    48  
    49  	s.shutdown = make(chan struct{})
    50  	s.api = api
    51  	s.interrupt = make(chan os.Signal, 1)
    52  	return s
    53  }
    54  
    55  // ConfigureAPI configures the API and handlers.
    56  func (s *Server) ConfigureAPI() {
    57  	if s.api != nil {
    58  		s.handler = configureAPI(s.api)
    59  	}
    60  }
    61  
    62  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    63  func (s *Server) ConfigureFlags() {
    64  	if s.api != nil {
    65  		configureFlags(s.api)
    66  	}
    67  }
    68  
// Server for the auth sample API
//
// The exported fields are command-line options, described by go-flags
// struct tags (long name, description, default, env var). The unexported
// fields hold the opened listeners and runtime state.
type Server struct {
	// Options shared by every listener. CleanupTimeout becomes each
	// http.Server's IdleTimeout. GracefulTimeout bounds the graceful
	// shutdown of all servers.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket option and its listener (set by Listen).
	SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/auth-sample.sock"`
	domainSocketL net.Listener

	// Plain HTTP options and listener. Listen overwrites Host/Port with the
	// address actually bound.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	httpServerL  net.Listener

	// HTTPS options and listener. When left zero, TLSHost, TLSListenLimit,
	// TLSKeepAlive, TLSReadTimeout and TLSWriteTimeout are copied from their
	// HTTP counterparts in Listen. Listen overwrites TLSHost/TLSPort with the
	// address actually bound.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	httpsServerL      net.Listener

	api          *operations.AuthSampleAPI
	handler      http.Handler
	// hasListeners is set once Listen has opened every enabled listener.
	hasListeners bool
	// shutdown is closed (once) by Shutdown to start the graceful shutdown.
	shutdown     chan struct{}
	// shuttingDown is a 0/1 flag, accessed atomically, guarding close(shutdown).
	shuttingDown int32
	// interrupted records that an interrupt signal has already been handled.
	interrupted  bool
	// interrupt receives SIGINT/SIGTERM once Serve has registered it.
	interrupt    chan os.Signal
}
   106  
   107  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   108  func (s *Server) Logf(f string, args ...interface{}) {
   109  	if s.api != nil && s.api.Logger != nil {
   110  		s.api.Logger(f, args...)
   111  	} else {
   112  		log.Printf(f, args...)
   113  	}
   114  }
   115  
   116  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   117  // Exits with non-zero status after printing
   118  func (s *Server) Fatalf(f string, args ...interface{}) {
   119  	if s.api != nil && s.api.Logger != nil {
   120  		s.api.Logger(f, args...)
   121  		os.Exit(1)
   122  	} else {
   123  		log.Fatalf(f, args...)
   124  	}
   125  }
   126  
   127  // SetAPI configures the server with the specified API. Needs to be called before Serve
   128  func (s *Server) SetAPI(api *operations.AuthSampleAPI) {
   129  	if api == nil {
   130  		s.api = nil
   131  		s.handler = nil
   132  		return
   133  	}
   134  
   135  	s.api = api
   136  	s.api.Logger = log.Printf
   137  	s.handler = configureAPI(api)
   138  }
   139  
   140  func (s *Server) hasScheme(scheme string) bool {
   141  	schemes := s.EnabledListeners
   142  	if len(schemes) == 0 {
   143  		schemes = defaultSchemes
   144  	}
   145  
   146  	for _, v := range schemes {
   147  		if v == scheme {
   148  			return true
   149  		}
   150  	}
   151  	return false
   152  }
   153  
// Serve serves the API.
//
// It opens the listeners (via Listen) if that has not happened yet. When no
// handler was set, it falls back to s.api.Serve(nil). It then starts one
// http.Server per enabled scheme (unix, http, https), each in its own
// goroutine. It also registers the SIGINT/SIGTERM handler and starts the
// shutdown watcher.
//
// Serve blocks until shutdown has been triggered (by Shutdown or by a
// signal) and every server goroutine has returned. A server that fails with
// anything other than http.ErrServerClosed terminates the process via
// Fatalf. Errors loading the TLS key pair or CA certificate are returned.
//
// NOTE(review): the TLS errors, and the Fatalf calls for missing
// certificates, are reached after the unix/http servers and the shutdown
// watcher have already been started. On an early return nothing stops them
// or waits for them. Validating the TLS settings before starting any server
// would avoid this; TODO confirm nobody relies on the current order.
func (s *Server) Serve() (err error) {
	// Open the listeners now unless Listen was already called.
	if !s.hasListeners {
		if err = s.Listen(); err != nil {
			return err
		}
	}

	// set default handler, if none is set
	if s.handler == nil {
		if s.api == nil {
			return errors.New("can't create the default handler, as no api is set")
		}

		s.SetHandler(s.api.Serve(nil))
	}

	// wg counts the shutdown watcher plus one goroutine per started server.
	wg := new(sync.WaitGroup)
	once := new(sync.Once)
	// Route SIGINT/SIGTERM into s.interrupt and turn them into Shutdown calls.
	signalNotify(s.interrupt)
	go handleInterrupt(once, s)

	// The watcher receives a pointer so that, once shutdown is signalled, it
	// sees the servers appended below.
	// NOTE(review): if shutdown is signalled while this function is still
	// appending, the watcher's read of the slice is not synchronized with the
	// appends. This is a possible data race; TODO confirm.
	servers := []*http.Server{}
	wg.Add(1)
	go s.handleShutdown(wg, &servers)

	if s.hasScheme(schemeUnix) {
		domainSocket := new(http.Server)
		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
		domainSocket.Handler = s.handler
		// CleanupTimeout is applied as the server's IdleTimeout.
		if int64(s.CleanupTimeout) > 0 {
			domainSocket.IdleTimeout = s.CleanupTimeout
		}

		// Per-server customization hook, defined elsewhere in this package.
		configureServer(domainSocket, "unix", string(s.SocketPath))

		servers = append(servers, domainSocket)
		wg.Add(1)
		s.Logf("Serving auth sample at unix://%s", s.SocketPath)
		go func(l net.Listener) {
			defer wg.Done()
			// ErrServerClosed is the normal result of a graceful Shutdown.
			// Any other error terminates the process.
			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving auth sample at unix://%s", s.SocketPath)
		}(s.domainSocketL)
	}

	if s.hasScheme(schemeHTTP) {
		httpServer := new(http.Server)
		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpServer.ReadTimeout = s.ReadTimeout
		httpServer.WriteTimeout = s.WriteTimeout
		// KeepAlive only switches HTTP keep-alives on or off here; the
		// duration itself is not applied anywhere in this file.
		// NOTE(review): the --keep-alive flag describes TCP keep-alive
		// timeouts. Confirm which behaviour is intended.
		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
		// Cap the number of simultaneously accepted connections.
		if s.ListenLimit > 0 {
			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
		}

		// CleanupTimeout is applied as the server's IdleTimeout.
		if int64(s.CleanupTimeout) > 0 {
			httpServer.IdleTimeout = s.CleanupTimeout
		}

		httpServer.Handler = s.handler

		// Per-server customization hook, defined elsewhere in this package.
		configureServer(httpServer, "http", s.httpServerL.Addr().String())

		servers = append(servers, httpServer)
		wg.Add(1)
		s.Logf("Serving auth sample at http://%s", s.httpServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			// ErrServerClosed is the normal result of a graceful Shutdown.
			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving auth sample at http://%s", l.Addr())
		}(s.httpServerL)
	}

	if s.hasScheme(schemeHTTPS) {
		httpsServer := new(http.Server)
		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
		httpsServer.ReadTimeout = s.TLSReadTimeout
		httpsServer.WriteTimeout = s.TLSWriteTimeout
		// As for HTTP: TLSKeepAlive only switches HTTP keep-alives on or off.
		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
		// Cap the number of simultaneously accepted connections.
		if s.TLSListenLimit > 0 {
			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
		}
		if int64(s.CleanupTimeout) > 0 {
			httpsServer.IdleTimeout = s.CleanupTimeout
		}
		httpsServer.Handler = s.handler

		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
		httpsServer.TLSConfig = &tls.Config{
			// Have the server pick cipher suites by its own preference order
			// (CipherSuites below) rather than the client's. Does nothing on
			// clients. NOTE(review): crypto/tls ignores this field since Go 1.18.
			PreferServerCipherSuites: true,
			// Only use curves which have assembly implementations
			// https://github.com/golang/go/tree/master/src/crypto/elliptic
			CurvePreferences: []tls.CurveID{tls.CurveP256},
			// Protocols offered via ALPN. "h2" lets net/http serve HTTP/2 on the
			// *tls.Conn connections produced by tls.NewListener below.
			// NOTE(review): crypto/tls appears to prefer the server's order here,
			// so listing "http/1.1" first may keep clients on HTTP/1.1. Confirm.
			NextProtos: []string{"http/1.1", "h2"},
			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
			MinVersion: tls.VersionTLS12,
			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
			// (The list only affects TLS 1.2; TLS 1.3 suites are not configurable.)
			CipherSuites: []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			},
		}

		// build standard config from server options
		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
			if err != nil {
				// NOTE(review): the unix/http servers started above keep running.
				return err
			}
		}

		if s.TLSCACertificate != "" {
			// include specified CA certificate
			// With a CA configured, every client must present a certificate
			// that verifies against it (mutual TLS).
			caCert, caCertErr := ioutil.ReadFile(string(s.TLSCACertificate))
			if caCertErr != nil {
				return caCertErr
			}
			caCertPool := x509.NewCertPool()
			ok := caCertPool.AppendCertsFromPEM(caCert)
			if !ok {
				return fmt.Errorf("cannot parse CA certificate")
			}
			httpsServer.TLSConfig.ClientCAs = caCertPool
			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
		}

		// call custom TLS configurator (defined elsewhere in this package); it
		// may supply the certificates itself
		configureTLS(httpsServer.TLSConfig)

		if len(httpsServer.TLSConfig.Certificates) == 0 {
			// after standard and custom config are passed, this ends up with no certificate
			if s.TLSCertificate == "" {
				if s.TLSCertificateKey == "" {
					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
				}
				s.Fatalf("the required flag `--tls-certificate` was not specified")
			}
			if s.TLSCertificateKey == "" {
				s.Fatalf("the required flag `--tls-key` was not specified")
			}
			// this happens with a wrong custom TLS configurator
			s.Fatalf("no certificate was configured for TLS")
		}

		// Populates TLSConfig.NameToCertificate from Certificates.
		// NOTE(review): deprecated since Go 1.14. When NameToCertificate is
		// left nil, crypto/tls selects from Certificates itself.
		httpsServer.TLSConfig.BuildNameToCertificate()

		// Per-server customization hook, defined elsewhere in this package.
		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())

		servers = append(servers, httpsServer)
		wg.Add(1)
		s.Logf("Serving auth sample at https://%s", s.httpsServerL.Addr())
		go func(l net.Listener) {
			defer wg.Done()
			// ErrServerClosed is the normal result of a graceful Shutdown.
			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
				s.Fatalf("%v", err)
			}
			s.Logf("Stopped serving auth sample at https://%s", l.Addr())
		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
	}

	// Block until the shutdown watcher and every server goroutine are done.
	wg.Wait()
	return nil
}
   331  
   332  // Listen creates the listeners for the server
   333  func (s *Server) Listen() error {
   334  	if s.hasListeners { // already done this
   335  		return nil
   336  	}
   337  
   338  	if s.hasScheme(schemeHTTPS) {
   339  		// Use http host if https host wasn't defined
   340  		if s.TLSHost == "" {
   341  			s.TLSHost = s.Host
   342  		}
   343  		// Use http listen limit if https listen limit wasn't defined
   344  		if s.TLSListenLimit == 0 {
   345  			s.TLSListenLimit = s.ListenLimit
   346  		}
   347  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   348  		if int64(s.TLSKeepAlive) == 0 {
   349  			s.TLSKeepAlive = s.KeepAlive
   350  		}
   351  		// Use http read timeout if https read timeout wasn't defined
   352  		if int64(s.TLSReadTimeout) == 0 {
   353  			s.TLSReadTimeout = s.ReadTimeout
   354  		}
   355  		// Use http write timeout if https write timeout wasn't defined
   356  		if int64(s.TLSWriteTimeout) == 0 {
   357  			s.TLSWriteTimeout = s.WriteTimeout
   358  		}
   359  	}
   360  
   361  	if s.hasScheme(schemeUnix) {
   362  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   363  		if err != nil {
   364  			return err
   365  		}
   366  		s.domainSocketL = domSockListener
   367  	}
   368  
   369  	if s.hasScheme(schemeHTTP) {
   370  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   371  		if err != nil {
   372  			return err
   373  		}
   374  
   375  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   376  		if err != nil {
   377  			return err
   378  		}
   379  		s.Host = h
   380  		s.Port = p
   381  		s.httpServerL = listener
   382  	}
   383  
   384  	if s.hasScheme(schemeHTTPS) {
   385  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   386  		if err != nil {
   387  			return err
   388  		}
   389  
   390  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   391  		if err != nil {
   392  			return err
   393  		}
   394  		s.TLSHost = sh
   395  		s.TLSPort = sp
   396  		s.httpsServerL = tlsListener
   397  	}
   398  
   399  	s.hasListeners = true
   400  	return nil
   401  }
   402  
   403  // Shutdown server and clean up resources
   404  func (s *Server) Shutdown() error {
   405  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   406  		close(s.shutdown)
   407  	}
   408  	return nil
   409  }
   410  
   411  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   412  	// wg.Done must occur last, after s.api.ServerShutdown()
   413  	// (to preserve old behaviour)
   414  	defer wg.Done()
   415  
   416  	<-s.shutdown
   417  
   418  	servers := *serversPtr
   419  
   420  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   421  	defer cancel()
   422  
   423  	shutdownChan := make(chan bool)
   424  	for i := range servers {
   425  		server := servers[i]
   426  		go func() {
   427  			var success bool
   428  			defer func() {
   429  				shutdownChan <- success
   430  			}()
   431  			if err := server.Shutdown(ctx); err != nil {
   432  				// Error from closing listeners, or context timeout:
   433  				s.Logf("HTTP server Shutdown: %v", err)
   434  			} else {
   435  				success = true
   436  			}
   437  		}()
   438  	}
   439  
   440  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   441  	success := true
   442  	for range servers {
   443  		success = success && <-shutdownChan
   444  	}
   445  	if success {
   446  		s.api.ServerShutdown()
   447  	}
   448  }
   449  
   450  // GetHandler returns a handler useful for testing
   451  func (s *Server) GetHandler() http.Handler {
   452  	return s.handler
   453  }
   454  
   455  // SetHandler allows for setting a http handler on this server
   456  func (s *Server) SetHandler(handler http.Handler) {
   457  	s.handler = handler
   458  }
   459  
   460  // UnixListener returns the domain socket listener
   461  func (s *Server) UnixListener() (net.Listener, error) {
   462  	if !s.hasListeners {
   463  		if err := s.Listen(); err != nil {
   464  			return nil, err
   465  		}
   466  	}
   467  	return s.domainSocketL, nil
   468  }
   469  
   470  // HTTPListener returns the http listener
   471  func (s *Server) HTTPListener() (net.Listener, error) {
   472  	if !s.hasListeners {
   473  		if err := s.Listen(); err != nil {
   474  			return nil, err
   475  		}
   476  	}
   477  	return s.httpServerL, nil
   478  }
   479  
   480  // TLSListener returns the https listener
   481  func (s *Server) TLSListener() (net.Listener, error) {
   482  	if !s.hasListeners {
   483  		if err := s.Listen(); err != nil {
   484  			return nil, err
   485  		}
   486  	}
   487  	return s.httpsServerL, nil
   488  }
   489  
   490  func handleInterrupt(once *sync.Once, s *Server) {
   491  	once.Do(func() {
   492  		for _ = range s.interrupt {
   493  			if s.interrupted {
   494  				s.Logf("Server already shutting down")
   495  				continue
   496  			}
   497  			s.interrupted = true
   498  			s.Logf("Shutting down... ")
   499  			if err := s.Shutdown(); err != nil {
   500  				s.Logf("HTTP server Shutdown: %v", err)
   501  			}
   502  		}
   503  	})
   504  }
   505  
   506  func signalNotify(interrupt chan<- os.Signal) {
   507  	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
   508  }