github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/server.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  // Code generated by go-swagger; DO NOT EDIT.
    13  
    14  package rest
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"errors"
    21  	"fmt"
    22  	"log"
    23  	"net"
    24  	"net/http"
    25  	"os"
    26  	"os/signal"
    27  	"strconv"
    28  	"sync"
    29  	"sync/atomic"
    30  	"syscall"
    31  	"time"
    32  
    33  	"github.com/go-openapi/runtime/flagext"
    34  	"github.com/go-openapi/swag"
    35  	flags "github.com/jessevdk/go-flags"
    36  	"golang.org/x/net/netutil"
    37  
    38  	"github.com/weaviate/weaviate/adapters/handlers/rest/operations"
    39  )
    40  
    41  const (
    42  	schemeHTTP  = "http"
    43  	schemeHTTPS = "https"
    44  	schemeUnix  = "unix"
    45  )
    46  
    47  var defaultSchemes []string
    48  
    49  func init() {
    50  	defaultSchemes = []string{
    51  		schemeHTTPS,
    52  	}
    53  }
    54  
    55  // NewServer creates a new api weaviate server but does not configure it
    56  func NewServer(api *operations.WeaviateAPI) *Server {
    57  	s := new(Server)
    58  
    59  	s.shutdown = make(chan struct{})
    60  	s.api = api
    61  	s.interrupt = make(chan os.Signal, 1)
    62  	return s
    63  }
    64  
    65  // ConfigureAPI configures the API and handlers.
    66  func (s *Server) ConfigureAPI() {
    67  	if s.api != nil {
    68  		s.handler = configureAPI(s.api)
    69  	}
    70  }
    71  
    72  // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse
    73  func (s *Server) ConfigureFlags() {
    74  	if s.api != nil {
    75  		configureFlags(s.api)
    76  	}
    77  }
    78  
// Server for the weaviate API
//
// Exported fields are command-line options (described by the go-flags
// struct tags). Unexported fields hold the listeners opened by Listen and
// the runtime state used by Serve, Shutdown and the signal handling.
type Server struct {
	// Options shared by all listeners. EnabledListeners is read by
	// hasScheme; when it is empty, defaultSchemes is used instead.
	// CleanupTimeout becomes http.Server.IdleTimeout when > 0.
	// GracefulTimeout bounds the context handed to http.Server.Shutdown.
	// MaxHeaderSize becomes http.Server.MaxHeaderBytes.
	EnabledListeners []string         `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"`
	CleanupTimeout   time.Duration    `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"`
	GracefulTimeout  time.Duration    `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"`
	MaxHeaderSize    flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"`

	// Unix domain socket listener, used for the "unix" scheme.
	SocketPath    flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/weaviate.sock"`
	domainSocketL net.Listener

	// Plain HTTP listener, used for the "http" scheme. After Listen, Host and
	// Port hold the address that was actually bound. ListenLimit > 0 wraps
	// the listener with netutil.LimitListener in Serve.
	// NOTE(review): despite its description, KeepAlive is only used in this
	// file to switch HTTP keep-alives on (> 0) or off. It is not applied as
	// a TCP keep-alive period here.
	Host         string        `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"`
	Port         int           `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"`
	ListenLimit  int           `long:"listen-limit" description:"limit the number of outstanding requests"`
	KeepAlive    time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"`
	ReadTimeout  time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"`
	WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"`
	httpServerL  net.Listener

	// HTTPS listener, used for the "https" scheme. Listen copies TLSHost,
	// TLSListenLimit, TLSKeepAlive, TLSReadTimeout and TLSWriteTimeout from
	// their plain-HTTP counterparts when they are left empty/zero. Setting
	// TLSCACertificate makes Serve require and verify client certificates.
	// httpsServerL is the raw TCP listener; Serve wraps it with TLS.
	TLSHost           string         `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"`
	TLSPort           int            `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"`
	TLSCertificate    flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"`
	TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"`
	TLSCACertificate  flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"`
	TLSListenLimit    int            `long:"tls-listen-limit" description:"limit the number of outstanding requests"`
	TLSKeepAlive      time.Duration  `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"`
	TLSReadTimeout    time.Duration  `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"`
	TLSWriteTimeout   time.Duration  `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"`
	httpsServerL      net.Listener

	// Runtime state.
	// - api: the swagger API.
	// - handler: what every server dispatches requests to.
	// - hasListeners: records that Listen succeeded.
	// - shutdown: closed exactly once by Shutdown (guarded by a
	//   compare-and-swap on shuttingDown).
	// - interrupted: set by handleInterrupt on the first signal.
	// - interrupt: receives SIGINT/SIGTERM once Serve calls signalNotify.
	api          *operations.WeaviateAPI
	handler      http.Handler
	hasListeners bool
	shutdown     chan struct{}
	shuttingDown int32
	interrupted  bool
	interrupt    chan os.Signal
}
   116  
   117  // Logf logs message either via defined user logger or via system one if no user logger is defined.
   118  func (s *Server) Logf(f string, args ...interface{}) {
   119  	if s.api != nil && s.api.Logger != nil {
   120  		s.api.Logger(f, args...)
   121  	} else {
   122  		log.Printf(f, args...)
   123  	}
   124  }
   125  
   126  // Fatalf logs message either via defined user logger or via system one if no user logger is defined.
   127  // Exits with non-zero status after printing
   128  func (s *Server) Fatalf(f string, args ...interface{}) {
   129  	if s.api != nil && s.api.Logger != nil {
   130  		s.api.Logger(f, args...)
   131  		os.Exit(1)
   132  	} else {
   133  		log.Fatalf(f, args...)
   134  	}
   135  }
   136  
   137  // SetAPI configures the server with the specified API. Needs to be called before Serve
   138  func (s *Server) SetAPI(api *operations.WeaviateAPI) {
   139  	if api == nil {
   140  		s.api = nil
   141  		s.handler = nil
   142  		return
   143  	}
   144  
   145  	s.api = api
   146  	s.handler = configureAPI(api)
   147  }
   148  
   149  func (s *Server) hasScheme(scheme string) bool {
   150  	schemes := s.EnabledListeners
   151  	if len(schemes) == 0 {
   152  		schemes = defaultSchemes
   153  	}
   154  
   155  	for _, v := range schemes {
   156  		if v == scheme {
   157  			return true
   158  		}
   159  	}
   160  	return false
   161  }
   162  
   163  // Serve the api
   164  func (s *Server) Serve() (err error) {
   165  	if !s.hasListeners {
   166  		if err = s.Listen(); err != nil {
   167  			return err
   168  		}
   169  	}
   170  
   171  	// set default handler, if none is set
   172  	if s.handler == nil {
   173  		if s.api == nil {
   174  			return errors.New("can't create the default handler, as no api is set")
   175  		}
   176  
   177  		s.SetHandler(s.api.Serve(nil))
   178  	}
   179  
   180  	wg := new(sync.WaitGroup)
   181  	once := new(sync.Once)
   182  	signalNotify(s.interrupt)
   183  	go handleInterrupt(once, s)
   184  
   185  	servers := []*http.Server{}
   186  
   187  	if s.hasScheme(schemeUnix) {
   188  		domainSocket := new(http.Server)
   189  		domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
   190  		domainSocket.Handler = s.handler
   191  		if int64(s.CleanupTimeout) > 0 {
   192  			domainSocket.IdleTimeout = s.CleanupTimeout
   193  		}
   194  
   195  		configureServer(domainSocket, "unix", string(s.SocketPath))
   196  
   197  		servers = append(servers, domainSocket)
   198  		wg.Add(1)
   199  		s.Logf("Serving weaviate at unix://%s", s.SocketPath)
   200  		go func(l net.Listener) {
   201  			defer wg.Done()
   202  			if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
   203  				s.Fatalf("%v", err)
   204  			}
   205  			s.Logf("Stopped serving weaviate at unix://%s", s.SocketPath)
   206  		}(s.domainSocketL)
   207  	}
   208  
   209  	if s.hasScheme(schemeHTTP) {
   210  		httpServer := new(http.Server)
   211  		httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   212  		httpServer.ReadTimeout = s.ReadTimeout
   213  		httpServer.WriteTimeout = s.WriteTimeout
   214  		httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
   215  		if s.ListenLimit > 0 {
   216  			s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
   217  		}
   218  
   219  		if int64(s.CleanupTimeout) > 0 {
   220  			httpServer.IdleTimeout = s.CleanupTimeout
   221  		}
   222  
   223  		httpServer.Handler = s.handler
   224  
   225  		configureServer(httpServer, "http", s.httpServerL.Addr().String())
   226  
   227  		servers = append(servers, httpServer)
   228  		wg.Add(1)
   229  		s.Logf("Serving weaviate at http://%s", s.httpServerL.Addr())
   230  		go func(l net.Listener) {
   231  			defer wg.Done()
   232  			if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
   233  				s.Fatalf("%v", err)
   234  			}
   235  			s.Logf("Stopped serving weaviate at http://%s", l.Addr())
   236  		}(s.httpServerL)
   237  	}
   238  
   239  	if s.hasScheme(schemeHTTPS) {
   240  		httpsServer := new(http.Server)
   241  		httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
   242  		httpsServer.ReadTimeout = s.TLSReadTimeout
   243  		httpsServer.WriteTimeout = s.TLSWriteTimeout
   244  		httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
   245  		if s.TLSListenLimit > 0 {
   246  			s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
   247  		}
   248  		if int64(s.CleanupTimeout) > 0 {
   249  			httpsServer.IdleTimeout = s.CleanupTimeout
   250  		}
   251  		httpsServer.Handler = s.handler
   252  
   253  		// Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go
   254  		httpsServer.TLSConfig = &tls.Config{
   255  			// Causes servers to use Go's default ciphersuite preferences,
   256  			// which are tuned to avoid attacks. Does nothing on clients.
   257  			PreferServerCipherSuites: true,
   258  			// Only use curves which have assembly implementations
   259  			// https://github.com/golang/go/tree/master/src/crypto/elliptic
   260  			CurvePreferences: []tls.CurveID{tls.CurveP256},
   261  			// Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility
   262  			NextProtos: []string{"h2", "http/1.1"},
   263  			// https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols
   264  			MinVersion: tls.VersionTLS12,
   265  			// These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy
   266  			CipherSuites: []uint16{
   267  				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   268  				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
   269  				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   270  				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   271  				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
   272  				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
   273  			},
   274  		}
   275  
   276  		// build standard config from server options
   277  		if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
   278  			httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
   279  			httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey))
   280  			if err != nil {
   281  				return err
   282  			}
   283  		}
   284  
   285  		if s.TLSCACertificate != "" {
   286  			// include specified CA certificate
   287  			caCert, caCertErr := os.ReadFile(string(s.TLSCACertificate))
   288  			if caCertErr != nil {
   289  				return caCertErr
   290  			}
   291  			caCertPool := x509.NewCertPool()
   292  			ok := caCertPool.AppendCertsFromPEM(caCert)
   293  			if !ok {
   294  				return fmt.Errorf("cannot parse CA certificate")
   295  			}
   296  			httpsServer.TLSConfig.ClientCAs = caCertPool
   297  			httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
   298  		}
   299  
   300  		// call custom TLS configurator
   301  		configureTLS(httpsServer.TLSConfig)
   302  
   303  		if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
   304  			// after standard and custom config are passed, this ends up with no certificate
   305  			if s.TLSCertificate == "" {
   306  				if s.TLSCertificateKey == "" {
   307  					s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
   308  				}
   309  				s.Fatalf("the required flag `--tls-certificate` was not specified")
   310  			}
   311  			if s.TLSCertificateKey == "" {
   312  				s.Fatalf("the required flag `--tls-key` was not specified")
   313  			}
   314  			// this happens with a wrong custom TLS configurator
   315  			s.Fatalf("no certificate was configured for TLS")
   316  		}
   317  
   318  		configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
   319  
   320  		servers = append(servers, httpsServer)
   321  		wg.Add(1)
   322  		s.Logf("Serving weaviate at https://%s", s.httpsServerL.Addr())
   323  		go func(l net.Listener) {
   324  			defer wg.Done()
   325  			if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
   326  				s.Fatalf("%v", err)
   327  			}
   328  			s.Logf("Stopped serving weaviate at https://%s", l.Addr())
   329  		}(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
   330  	}
   331  
   332  	wg.Add(1)
   333  	go s.handleShutdown(wg, &servers)
   334  
   335  	wg.Wait()
   336  	return nil
   337  }
   338  
   339  // Listen creates the listeners for the server
   340  func (s *Server) Listen() error {
   341  	if s.hasListeners { // already done this
   342  		return nil
   343  	}
   344  
   345  	if s.hasScheme(schemeHTTPS) {
   346  		// Use http host if https host wasn't defined
   347  		if s.TLSHost == "" {
   348  			s.TLSHost = s.Host
   349  		}
   350  		// Use http listen limit if https listen limit wasn't defined
   351  		if s.TLSListenLimit == 0 {
   352  			s.TLSListenLimit = s.ListenLimit
   353  		}
   354  		// Use http tcp keep alive if https tcp keep alive wasn't defined
   355  		if int64(s.TLSKeepAlive) == 0 {
   356  			s.TLSKeepAlive = s.KeepAlive
   357  		}
   358  		// Use http read timeout if https read timeout wasn't defined
   359  		if int64(s.TLSReadTimeout) == 0 {
   360  			s.TLSReadTimeout = s.ReadTimeout
   361  		}
   362  		// Use http write timeout if https write timeout wasn't defined
   363  		if int64(s.TLSWriteTimeout) == 0 {
   364  			s.TLSWriteTimeout = s.WriteTimeout
   365  		}
   366  	}
   367  
   368  	if s.hasScheme(schemeUnix) {
   369  		domSockListener, err := net.Listen("unix", string(s.SocketPath))
   370  		if err != nil {
   371  			return err
   372  		}
   373  		s.domainSocketL = domSockListener
   374  	}
   375  
   376  	if s.hasScheme(schemeHTTP) {
   377  		listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
   378  		if err != nil {
   379  			return err
   380  		}
   381  
   382  		h, p, err := swag.SplitHostPort(listener.Addr().String())
   383  		if err != nil {
   384  			return err
   385  		}
   386  		s.Host = h
   387  		s.Port = p
   388  		s.httpServerL = listener
   389  	}
   390  
   391  	if s.hasScheme(schemeHTTPS) {
   392  		tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
   393  		if err != nil {
   394  			return err
   395  		}
   396  
   397  		sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
   398  		if err != nil {
   399  			return err
   400  		}
   401  		s.TLSHost = sh
   402  		s.TLSPort = sp
   403  		s.httpsServerL = tlsListener
   404  	}
   405  
   406  	s.hasListeners = true
   407  	return nil
   408  }
   409  
   410  // Shutdown server and clean up resources
   411  func (s *Server) Shutdown() error {
   412  	if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
   413  		close(s.shutdown)
   414  	}
   415  	return nil
   416  }
   417  
   418  func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
   419  	// wg.Done must occur last, after s.api.ServerShutdown()
   420  	// (to preserve old behaviour)
   421  	defer wg.Done()
   422  
   423  	<-s.shutdown
   424  
   425  	servers := *serversPtr
   426  
   427  	ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
   428  	defer cancel()
   429  
   430  	// first execute the pre-shutdown hook
   431  	s.api.PreServerShutdown()
   432  
   433  	shutdownChan := make(chan bool)
   434  	for i := range servers {
   435  		server := servers[i]
   436  		go func() {
   437  			var success bool
   438  			defer func() {
   439  				shutdownChan <- success
   440  			}()
   441  			if err := server.Shutdown(ctx); err != nil {
   442  				// Error from closing listeners, or context timeout:
   443  				s.Logf("HTTP server Shutdown: %v", err)
   444  			} else {
   445  				success = true
   446  			}
   447  		}()
   448  	}
   449  
   450  	// Wait until all listeners have successfully shut down before calling ServerShutdown
   451  	success := true
   452  	for range servers {
   453  		success = success && <-shutdownChan
   454  	}
   455  	if success {
   456  		s.api.ServerShutdown()
   457  	}
   458  }
   459  
   460  // GetHandler returns a handler useful for testing
   461  func (s *Server) GetHandler() http.Handler {
   462  	return s.handler
   463  }
   464  
   465  // SetHandler allows for setting a http handler on this server
   466  func (s *Server) SetHandler(handler http.Handler) {
   467  	s.handler = handler
   468  }
   469  
   470  // UnixListener returns the domain socket listener
   471  func (s *Server) UnixListener() (net.Listener, error) {
   472  	if !s.hasListeners {
   473  		if err := s.Listen(); err != nil {
   474  			return nil, err
   475  		}
   476  	}
   477  	return s.domainSocketL, nil
   478  }
   479  
   480  // HTTPListener returns the http listener
   481  func (s *Server) HTTPListener() (net.Listener, error) {
   482  	if !s.hasListeners {
   483  		if err := s.Listen(); err != nil {
   484  			return nil, err
   485  		}
   486  	}
   487  	return s.httpServerL, nil
   488  }
   489  
   490  // TLSListener returns the https listener
   491  func (s *Server) TLSListener() (net.Listener, error) {
   492  	if !s.hasListeners {
   493  		if err := s.Listen(); err != nil {
   494  			return nil, err
   495  		}
   496  	}
   497  	return s.httpsServerL, nil
   498  }
   499  
   500  func handleInterrupt(once *sync.Once, s *Server) {
   501  	once.Do(func() {
   502  		for range s.interrupt {
   503  			if s.interrupted {
   504  				s.Logf("Server already shutting down")
   505  				continue
   506  			}
   507  			s.interrupted = true
   508  			s.Logf("Shutting down... ")
   509  			if err := s.Shutdown(); err != nil {
   510  				s.Logf("HTTP server Shutdown: %v", err)
   511  			}
   512  		}
   513  	})
   514  }
   515  
   516  func signalNotify(interrupt chan<- os.Signal) {
   517  	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
   518  }