github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/server.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 // Code generated by go-swagger; DO NOT EDIT. 13 14 package rest 15 16 import ( 17 "context" 18 "crypto/tls" 19 "crypto/x509" 20 "errors" 21 "fmt" 22 "log" 23 "net" 24 "net/http" 25 "os" 26 "os/signal" 27 "strconv" 28 "sync" 29 "sync/atomic" 30 "syscall" 31 "time" 32 33 "github.com/go-openapi/runtime/flagext" 34 "github.com/go-openapi/swag" 35 flags "github.com/jessevdk/go-flags" 36 "golang.org/x/net/netutil" 37 38 "github.com/weaviate/weaviate/adapters/handlers/rest/operations" 39 ) 40 41 const ( 42 schemeHTTP = "http" 43 schemeHTTPS = "https" 44 schemeUnix = "unix" 45 ) 46 47 var defaultSchemes []string 48 49 func init() { 50 defaultSchemes = []string{ 51 schemeHTTPS, 52 } 53 } 54 55 // NewServer creates a new api weaviate server but does not configure it 56 func NewServer(api *operations.WeaviateAPI) *Server { 57 s := new(Server) 58 59 s.shutdown = make(chan struct{}) 60 s.api = api 61 s.interrupt = make(chan os.Signal, 1) 62 return s 63 } 64 65 // ConfigureAPI configures the API and handlers. 66 func (s *Server) ConfigureAPI() { 67 if s.api != nil { 68 s.handler = configureAPI(s.api) 69 } 70 } 71 72 // ConfigureFlags configures the additional flags defined by the handlers. Needs to be called before the parser.Parse 73 func (s *Server) ConfigureFlags() { 74 if s.api != nil { 75 configureFlags(s.api) 76 } 77 } 78 79 // Server for the weaviate API 80 type Server struct { 81 EnabledListeners []string `long:"scheme" description:"the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec"` 82 CleanupTimeout time.Duration `long:"cleanup-timeout" description:"grace period for which to wait before killing idle connections" default:"10s"` 83 GracefulTimeout time.Duration `long:"graceful-timeout" description:"grace period for which to wait before shutting down the server" default:"15s"` 84 MaxHeaderSize flagext.ByteSize `long:"max-header-size" description:"controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body." default:"1MiB"` 85 86 SocketPath flags.Filename `long:"socket-path" description:"the unix socket to listen on" default:"/var/run/weaviate.sock"` 87 domainSocketL net.Listener 88 89 Host string `long:"host" description:"the IP to listen on" default:"localhost" env:"HOST"` 90 Port int `long:"port" description:"the port to listen on for insecure connections, defaults to a random value" env:"PORT"` 91 ListenLimit int `long:"listen-limit" description:"limit the number of outstanding requests"` 92 KeepAlive time.Duration `long:"keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)" default:"3m"` 93 ReadTimeout time.Duration `long:"read-timeout" description:"maximum duration before timing out read of the request" default:"30s"` 94 WriteTimeout time.Duration `long:"write-timeout" description:"maximum duration before timing out write of the response" default:"60s"` 95 httpServerL net.Listener 96 97 TLSHost string `long:"tls-host" description:"the IP to listen on for tls, when not specified it's the same as --host" env:"TLS_HOST"` 98 TLSPort int `long:"tls-port" description:"the port to listen on for secure connections, defaults to a random value" env:"TLS_PORT"` 99 TLSCertificate flags.Filename `long:"tls-certificate" description:"the certificate to use for secure connections" env:"TLS_CERTIFICATE"` 100 TLSCertificateKey flags.Filename `long:"tls-key" description:"the private key to use for secure connections" env:"TLS_PRIVATE_KEY"` 101 TLSCACertificate flags.Filename `long:"tls-ca" description:"the certificate authority file to be used with mutual tls auth" env:"TLS_CA_CERTIFICATE"` 102 TLSListenLimit int `long:"tls-listen-limit" description:"limit the number of outstanding requests"` 103 TLSKeepAlive time.Duration `long:"tls-keep-alive" description:"sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)"` 104 TLSReadTimeout time.Duration `long:"tls-read-timeout" description:"maximum duration before timing out read of the request"` 105 TLSWriteTimeout time.Duration `long:"tls-write-timeout" description:"maximum duration before timing out write of the response"` 106 httpsServerL net.Listener 107 108 api *operations.WeaviateAPI 109 handler http.Handler 110 hasListeners bool 111 shutdown chan struct{} 112 shuttingDown int32 113 interrupted bool 114 interrupt chan os.Signal 115 } 116 117 // Logf logs message either via defined user logger or via system one if no user logger is defined. 118 func (s *Server) Logf(f string, args ...interface{}) { 119 if s.api != nil && s.api.Logger != nil { 120 s.api.Logger(f, args...) 121 } else { 122 log.Printf(f, args...) 123 } 124 } 125 126 // Fatalf logs message either via defined user logger or via system one if no user logger is defined. 127 // Exits with non-zero status after printing 128 func (s *Server) Fatalf(f string, args ...interface{}) { 129 if s.api != nil && s.api.Logger != nil { 130 s.api.Logger(f, args...) 131 os.Exit(1) 132 } else { 133 log.Fatalf(f, args...) 134 } 135 } 136 137 // SetAPI configures the server with the specified API. Needs to be called before Serve 138 func (s *Server) SetAPI(api *operations.WeaviateAPI) { 139 if api == nil { 140 s.api = nil 141 s.handler = nil 142 return 143 } 144 145 s.api = api 146 s.handler = configureAPI(api) 147 } 148 149 func (s *Server) hasScheme(scheme string) bool { 150 schemes := s.EnabledListeners 151 if len(schemes) == 0 { 152 schemes = defaultSchemes 153 } 154 155 for _, v := range schemes { 156 if v == scheme { 157 return true 158 } 159 } 160 return false 161 } 162 163 // Serve the api 164 func (s *Server) Serve() (err error) { 165 if !s.hasListeners { 166 if err = s.Listen(); err != nil { 167 return err 168 } 169 } 170 171 // set default handler, if none is set 172 if s.handler == nil { 173 if s.api == nil { 174 return errors.New("can't create the default handler, as no api is set") 175 } 176 177 s.SetHandler(s.api.Serve(nil)) 178 } 179 180 wg := new(sync.WaitGroup) 181 once := new(sync.Once) 182 signalNotify(s.interrupt) 183 go handleInterrupt(once, s) 184 185 servers := []*http.Server{} 186 187 if s.hasScheme(schemeUnix) { 188 domainSocket := new(http.Server) 189 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize) 190 domainSocket.Handler = s.handler 191 if int64(s.CleanupTimeout) > 0 { 192 domainSocket.IdleTimeout = s.CleanupTimeout 193 } 194 195 configureServer(domainSocket, "unix", string(s.SocketPath)) 196 197 servers = append(servers, domainSocket) 198 wg.Add(1) 199 s.Logf("Serving weaviate at unix://%s", s.SocketPath) 200 go func(l net.Listener) { 201 defer wg.Done() 202 if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed { 203 s.Fatalf("%v", err) 204 } 205 s.Logf("Stopped serving weaviate at unix://%s", s.SocketPath) 206 }(s.domainSocketL) 207 } 208 209 if s.hasScheme(schemeHTTP) { 210 httpServer := new(http.Server) 211 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize) 212 httpServer.ReadTimeout = s.ReadTimeout 213 httpServer.WriteTimeout = s.WriteTimeout 214 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0) 215 if s.ListenLimit > 0 { 216 s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit) 217 } 218 219 if int64(s.CleanupTimeout) > 0 { 220 httpServer.IdleTimeout = s.CleanupTimeout 221 } 222 223 httpServer.Handler = s.handler 224 225 configureServer(httpServer, "http", s.httpServerL.Addr().String()) 226 227 servers = append(servers, httpServer) 228 wg.Add(1) 229 s.Logf("Serving weaviate at http://%s", s.httpServerL.Addr()) 230 go func(l net.Listener) { 231 defer wg.Done() 232 if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed { 233 s.Fatalf("%v", err) 234 } 235 s.Logf("Stopped serving weaviate at http://%s", l.Addr()) 236 }(s.httpServerL) 237 } 238 239 if s.hasScheme(schemeHTTPS) { 240 httpsServer := new(http.Server) 241 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize) 242 httpsServer.ReadTimeout = s.TLSReadTimeout 243 httpsServer.WriteTimeout = s.TLSWriteTimeout 244 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0) 245 if s.TLSListenLimit > 0 { 246 s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit) 247 } 248 if int64(s.CleanupTimeout) > 0 { 249 httpsServer.IdleTimeout = s.CleanupTimeout 250 } 251 httpsServer.Handler = s.handler 252 253 // Inspired by https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go 254 httpsServer.TLSConfig = &tls.Config{ 255 // Causes servers to use Go's default ciphersuite preferences, 256 // which are tuned to avoid attacks. Does nothing on clients. 257 PreferServerCipherSuites: true, 258 // Only use curves which have assembly implementations 259 // https://github.com/golang/go/tree/master/src/crypto/elliptic 260 CurvePreferences: []tls.CurveID{tls.CurveP256}, 261 // Use modern tls mode https://wiki.mozilla.org/Security/Server_Side_TLS#Modern_compatibility 262 NextProtos: []string{"h2", "http/1.1"}, 263 // https://www.owasp.org/index.php/Transport_Layer_Protection_Cheat_Sheet#Rule_-_Only_Support_Strong_Protocols 264 MinVersion: tls.VersionTLS12, 265 // These ciphersuites support Forward Secrecy: https://en.wikipedia.org/wiki/Forward_secrecy 266 CipherSuites: []uint16{ 267 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 268 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 269 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 270 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 271 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 272 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 273 }, 274 } 275 276 // build standard config from server options 277 if s.TLSCertificate != "" && s.TLSCertificateKey != "" { 278 httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1) 279 httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(string(s.TLSCertificate), string(s.TLSCertificateKey)) 280 if err != nil { 281 return err 282 } 283 } 284 285 if s.TLSCACertificate != "" { 286 // include specified CA certificate 287 caCert, caCertErr := os.ReadFile(string(s.TLSCACertificate)) 288 if caCertErr != nil { 289 return caCertErr 290 } 291 caCertPool := x509.NewCertPool() 292 ok := caCertPool.AppendCertsFromPEM(caCert) 293 if !ok { 294 return fmt.Errorf("cannot parse CA certificate") 295 } 296 httpsServer.TLSConfig.ClientCAs = caCertPool 297 httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert 298 } 299 300 // call custom TLS configurator 301 configureTLS(httpsServer.TLSConfig) 302 303 if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil { 304 // after standard and custom config are passed, this ends up with no certificate 305 if s.TLSCertificate == "" { 306 if s.TLSCertificateKey == "" { 307 s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified") 308 } 309 s.Fatalf("the required flag `--tls-certificate` was not specified") 310 } 311 if s.TLSCertificateKey == "" { 312 s.Fatalf("the required flag `--tls-key` was not specified") 313 } 314 // this happens with a wrong custom TLS configurator 315 s.Fatalf("no certificate was configured for TLS") 316 } 317 318 configureServer(httpsServer, "https", s.httpsServerL.Addr().String()) 319 320 servers = append(servers, httpsServer) 321 wg.Add(1) 322 s.Logf("Serving weaviate at https://%s", s.httpsServerL.Addr()) 323 go func(l net.Listener) { 324 defer wg.Done() 325 if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed { 326 s.Fatalf("%v", err) 327 } 328 s.Logf("Stopped serving weaviate at https://%s", l.Addr()) 329 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig)) 330 } 331 332 wg.Add(1) 333 go s.handleShutdown(wg, &servers) 334 335 wg.Wait() 336 return nil 337 } 338 339 // Listen creates the listeners for the server 340 func (s *Server) Listen() error { 341 if s.hasListeners { // already done this 342 return nil 343 } 344 345 if s.hasScheme(schemeHTTPS) { 346 // Use http host if https host wasn't defined 347 if s.TLSHost == "" { 348 s.TLSHost = s.Host 349 } 350 // Use http listen limit if https listen limit wasn't defined 351 if s.TLSListenLimit == 0 { 352 s.TLSListenLimit = s.ListenLimit 353 } 354 // Use http tcp keep alive if https tcp keep alive wasn't defined 355 if int64(s.TLSKeepAlive) == 0 { 356 s.TLSKeepAlive = s.KeepAlive 357 } 358 // Use http read timeout if https read timeout wasn't defined 359 if int64(s.TLSReadTimeout) == 0 { 360 s.TLSReadTimeout = s.ReadTimeout 361 } 362 // Use http write timeout if https write timeout wasn't defined 363 if int64(s.TLSWriteTimeout) == 0 { 364 s.TLSWriteTimeout = s.WriteTimeout 365 } 366 } 367 368 if s.hasScheme(schemeUnix) { 369 domSockListener, err := net.Listen("unix", string(s.SocketPath)) 370 if err != nil { 371 return err 372 } 373 s.domainSocketL = domSockListener 374 } 375 376 if s.hasScheme(schemeHTTP) { 377 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port))) 378 if err != nil { 379 return err 380 } 381 382 h, p, err := swag.SplitHostPort(listener.Addr().String()) 383 if err != nil { 384 return err 385 } 386 s.Host = h 387 s.Port = p 388 s.httpServerL = listener 389 } 390 391 if s.hasScheme(schemeHTTPS) { 392 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort))) 393 if err != nil { 394 return err 395 } 396 397 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String()) 398 if err != nil { 399 return err 400 } 401 s.TLSHost = sh 402 s.TLSPort = sp 403 s.httpsServerL = tlsListener 404 } 405 406 s.hasListeners = true 407 return nil 408 } 409 410 // Shutdown server and clean up resources 411 func (s *Server) Shutdown() error { 412 if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) { 413 close(s.shutdown) 414 } 415 return nil 416 } 417 418 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) { 419 // wg.Done must occur last, after s.api.ServerShutdown() 420 // (to preserve old behaviour) 421 defer wg.Done() 422 423 <-s.shutdown 424 425 servers := *serversPtr 426 427 ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout) 428 defer cancel() 429 430 // first execute the pre-shutdown hook 431 s.api.PreServerShutdown() 432 433 shutdownChan := make(chan bool) 434 for i := range servers { 435 server := servers[i] 436 go func() { 437 var success bool 438 defer func() { 439 shutdownChan <- success 440 }() 441 if err := server.Shutdown(ctx); err != nil { 442 // Error from closing listeners, or context timeout: 443 s.Logf("HTTP server Shutdown: %v", err) 444 } else { 445 success = true 446 } 447 }() 448 } 449 450 // Wait until all listeners have successfully shut down before calling ServerShutdown 451 success := true 452 for range servers { 453 success = success && <-shutdownChan 454 } 455 if success { 456 s.api.ServerShutdown() 457 } 458 } 459 460 // GetHandler returns a handler useful for testing 461 func (s *Server) GetHandler() http.Handler { 462 return s.handler 463 } 464 465 // SetHandler allows for setting a http handler on this server 466 func (s *Server) SetHandler(handler http.Handler) { 467 s.handler = handler 468 } 469 470 // UnixListener returns the domain socket listener 471 func (s *Server) UnixListener() (net.Listener, error) { 472 if !s.hasListeners { 473 if err := s.Listen(); err != nil { 474 return nil, err 475 } 476 } 477 return s.domainSocketL, nil 478 } 479 480 // HTTPListener returns the http listener 481 func (s *Server) HTTPListener() (net.Listener, error) { 482 if !s.hasListeners { 483 if err := s.Listen(); err != nil { 484 return nil, err 485 } 486 } 487 return s.httpServerL, nil 488 } 489 490 // TLSListener returns the https listener 491 func (s *Server) TLSListener() (net.Listener, error) { 492 if !s.hasListeners { 493 if err := s.Listen(); err != nil { 494 return nil, err 495 } 496 } 497 return s.httpsServerL, nil 498 } 499 500 func handleInterrupt(once *sync.Once, s *Server) { 501 once.Do(func() { 502 for range s.interrupt { 503 if s.interrupted { 504 s.Logf("Server already shutting down") 505 continue 506 } 507 s.interrupted = true 508 s.Logf("Shutting down... ") 509 if err := s.Shutdown(); err != nil { 510 s.Logf("HTTP server Shutdown: %v", err) 511 } 512 } 513 }) 514 } 515 516 func signalNotify(interrupt chan<- os.Signal) { 517 signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) 518 }