code.gitea.io/gitea@v1.19.3/modules/graceful/server.go (about)

     1  // Copyright 2019 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  // This code is highly inspired by endless go
     5  
     6  package graceful
     7  
     8  import (
     9  	"crypto/tls"
    10  	"net"
    11  	"os"
    12  	"strings"
    13  	"sync"
    14  	"sync/atomic"
    15  	"syscall"
    16  	"time"
    17  
    18  	"code.gitea.io/gitea/modules/log"
    19  	"code.gitea.io/gitea/modules/proxyprotocol"
    20  	"code.gitea.io/gitea/modules/setting"
    21  )
    22  
    23  var (
    24  	// DefaultReadTimeOut default read timeout
    25  	DefaultReadTimeOut time.Duration
    26  	// DefaultWriteTimeOut default write timeout
    27  	DefaultWriteTimeOut time.Duration
    28  	// DefaultMaxHeaderBytes default max header bytes
    29  	DefaultMaxHeaderBytes int
    30  	// PerWriteWriteTimeout timeout for writes
    31  	PerWriteWriteTimeout = 30 * time.Second
    32  	// PerWriteWriteTimeoutKbTime is a timeout taking account of how much there is to be written
    33  	PerWriteWriteTimeoutKbTime = 10 * time.Second
    34  )
    35  
    36  func init() {
    37  	DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
    38  }
    39  
    40  // ServeFunction represents a listen.Accept loop
    41  type ServeFunction = func(net.Listener) error
    42  
    43  // Server represents our graceful server
    44  type Server struct {
    45  	network              string
    46  	address              string
    47  	listener             net.Listener
    48  	wg                   sync.WaitGroup
    49  	state                state
    50  	lock                 *sync.RWMutex
    51  	BeforeBegin          func(network, address string)
    52  	OnShutdown           func()
    53  	PerWriteTimeout      time.Duration
    54  	PerWritePerKbTimeout time.Duration
    55  }
    56  
    57  // NewServer creates a server on network at provided address
    58  func NewServer(network, address, name string) *Server {
    59  	if GetManager().IsChild() {
    60  		log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    61  	} else {
    62  		log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    63  	}
    64  	srv := &Server{
    65  		wg:                   sync.WaitGroup{},
    66  		state:                stateInit,
    67  		lock:                 &sync.RWMutex{},
    68  		network:              network,
    69  		address:              address,
    70  		PerWriteTimeout:      setting.PerWriteTimeout,
    71  		PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
    72  	}
    73  
    74  	srv.BeforeBegin = func(network, addr string) {
    75  		log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
    76  	}
    77  
    78  	return srv
    79  }
    80  
    81  // ListenAndServe listens on the provided network address and then calls Serve
    82  // to handle requests on incoming connections.
    83  func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
    84  	go srv.awaitShutdown()
    85  
    86  	listener, err := GetListener(srv.network, srv.address)
    87  	if err != nil {
    88  		log.Error("Unable to GetListener: %v", err)
    89  		return err
    90  	}
    91  
    92  	// we need to wrap the listener to take account of our lifecycle
    93  	listener = newWrappedListener(listener, srv)
    94  
    95  	// Now we need to take account of ProxyProtocol settings...
    96  	if useProxyProtocol {
    97  		listener = &proxyprotocol.Listener{
    98  			Listener:           listener,
    99  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
   100  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
   101  		}
   102  	}
   103  	srv.listener = listener
   104  
   105  	srv.BeforeBegin(srv.network, srv.address)
   106  
   107  	return srv.Serve(serve)
   108  }
   109  
   110  // ListenAndServeTLSConfig listens on the provided network address and then calls
   111  // Serve to handle requests on incoming TLS connections.
   112  func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
   113  	go srv.awaitShutdown()
   114  
   115  	if tlsConfig.MinVersion == 0 {
   116  		tlsConfig.MinVersion = tls.VersionTLS12
   117  	}
   118  
   119  	listener, err := GetListener(srv.network, srv.address)
   120  	if err != nil {
   121  		log.Error("Unable to get Listener: %v", err)
   122  		return err
   123  	}
   124  
   125  	// we need to wrap the listener to take account of our lifecycle
   126  	listener = newWrappedListener(listener, srv)
   127  
   128  	// Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us
   129  	if useProxyProtocol && !proxyProtocolTLSBridging {
   130  		listener = &proxyprotocol.Listener{
   131  			Listener:           listener,
   132  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
   133  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
   134  		}
   135  	}
   136  
   137  	// Now handle the tls protocol
   138  	listener = tls.NewListener(listener, tlsConfig)
   139  
   140  	// Now if we're bridging then we need the proxy to tell us who we're bridging for...
   141  	if useProxyProtocol && proxyProtocolTLSBridging {
   142  		listener = &proxyprotocol.Listener{
   143  			Listener:           listener,
   144  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
   145  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
   146  		}
   147  	}
   148  
   149  	srv.listener = listener
   150  	srv.BeforeBegin(srv.network, srv.address)
   151  
   152  	return srv.Serve(serve)
   153  }
   154  
   155  // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
   156  // service goroutine for each. The service goroutines read requests and then call
   157  // handler to reply to them. Handler is typically nil, in which case the
   158  // DefaultServeMux is used.
   159  //
   160  // In addition to the standard Serve behaviour each connection is added to a
   161  // sync.Waitgroup so that all outstanding connections can be served before shutting
   162  // down the server.
   163  func (srv *Server) Serve(serve ServeFunction) error {
   164  	defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
   165  	srv.setState(stateRunning)
   166  	GetManager().RegisterServer()
   167  	err := serve(srv.listener)
   168  	log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
   169  	srv.wg.Wait()
   170  	srv.setState(stateTerminate)
   171  	GetManager().ServerDone()
   172  	// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
   173  	if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
   174  		return nil
   175  	}
   176  	return err
   177  }
   178  
   179  func (srv *Server) getState() state {
   180  	srv.lock.RLock()
   181  	defer srv.lock.RUnlock()
   182  
   183  	return srv.state
   184  }
   185  
   186  func (srv *Server) setState(st state) {
   187  	srv.lock.Lock()
   188  	defer srv.lock.Unlock()
   189  
   190  	srv.state = st
   191  }
   192  
   193  type filer interface {
   194  	File() (*os.File, error)
   195  }
   196  
   197  type wrappedListener struct {
   198  	net.Listener
   199  	stopped bool
   200  	server  *Server
   201  }
   202  
   203  func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
   204  	return &wrappedListener{
   205  		Listener: l,
   206  		server:   srv,
   207  	}
   208  }
   209  
   210  func (wl *wrappedListener) Accept() (net.Conn, error) {
   211  	var c net.Conn
   212  	// Set keepalive on TCPListeners connections.
   213  	if tcl, ok := wl.Listener.(*net.TCPListener); ok {
   214  		tc, err := tcl.AcceptTCP()
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  		_ = tc.SetKeepAlive(true)                  // see http.tcpKeepAliveListener
   219  		_ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
   220  		c = tc
   221  	} else {
   222  		var err error
   223  		c, err = wl.Listener.Accept()
   224  		if err != nil {
   225  			return nil, err
   226  		}
   227  	}
   228  
   229  	closed := int32(0)
   230  
   231  	c = &wrappedConn{
   232  		Conn:                 c,
   233  		server:               wl.server,
   234  		closed:               &closed,
   235  		perWriteTimeout:      wl.server.PerWriteTimeout,
   236  		perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
   237  	}
   238  
   239  	wl.server.wg.Add(1)
   240  	return c, nil
   241  }
   242  
   243  func (wl *wrappedListener) Close() error {
   244  	if wl.stopped {
   245  		return syscall.EINVAL
   246  	}
   247  
   248  	wl.stopped = true
   249  	return wl.Listener.Close()
   250  }
   251  
   252  func (wl *wrappedListener) File() (*os.File, error) {
   253  	// returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
   254  	return wl.Listener.(filer).File()
   255  }
   256  
   257  type wrappedConn struct {
   258  	net.Conn
   259  	server               *Server
   260  	closed               *int32
   261  	deadline             time.Time
   262  	perWriteTimeout      time.Duration
   263  	perWritePerKbTimeout time.Duration
   264  }
   265  
   266  func (w *wrappedConn) Write(p []byte) (n int, err error) {
   267  	if w.perWriteTimeout > 0 {
   268  		minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout
   269  		minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout)
   270  
   271  		w.deadline = w.deadline.Add(minTimeout)
   272  		if minDeadline.After(w.deadline) {
   273  			w.deadline = minDeadline
   274  		}
   275  		_ = w.Conn.SetWriteDeadline(w.deadline)
   276  	}
   277  	return w.Conn.Write(p)
   278  }
   279  
   280  func (w *wrappedConn) Close() error {
   281  	if atomic.CompareAndSwapInt32(w.closed, 0, 1) {
   282  		defer func() {
   283  			if err := recover(); err != nil {
   284  				select {
   285  				case <-GetManager().IsHammer():
   286  					// Likely deadlocked request released at hammertime
   287  					log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err)
   288  				default:
   289  					log.Error("Panic during connection close! %v", err)
   290  				}
   291  			}
   292  		}()
   293  		w.server.wg.Done()
   294  	}
   295  	return w.Conn.Close()
   296  }