code.gitea.io/gitea@v1.22.3/modules/graceful/server.go (about)

     1  // Copyright 2019 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  // This code is highly inspired by endless go
     5  
     6  package graceful
     7  
import (
	"crypto/tls"
	"fmt"
	"net"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"code.gitea.io/gitea/modules/log"
	"code.gitea.io/gitea/modules/proxyprotocol"
	"code.gitea.io/gitea/modules/setting"
)
    22  
// GetListener returns a net listener.
// This determines the implementation of net.Listener which the server will use,
// so that downstreams could provide their own Listener, such as with a hidden service or a p2p network.
// It defaults to DefaultGetListener and is called by both ListenAndServe and
// ListenAndServeTLSConfig with the server's network and address.
var GetListener = DefaultGetListener

// ServeFunction represents an Accept loop run over a listener.
// Server.Serve hands it the fully wrapped listener and waits for it to return;
// errors reporting a closed listener are treated by Serve as a normal shutdown.
type ServeFunction = func(net.Listener) error
    30  
    31  // Server represents our graceful server
    32  type Server struct {
    33  	network              string
    34  	address              string
    35  	listener             net.Listener
    36  	wg                   sync.WaitGroup
    37  	state                state
    38  	lock                 *sync.RWMutex
    39  	BeforeBegin          func(network, address string)
    40  	OnShutdown           func()
    41  	PerWriteTimeout      time.Duration
    42  	PerWritePerKbTimeout time.Duration
    43  }
    44  
    45  // NewServer creates a server on network at provided address
    46  func NewServer(network, address, name string) *Server {
    47  	if GetManager().IsChild() {
    48  		log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    49  	} else {
    50  		log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    51  	}
    52  	srv := &Server{
    53  		wg:                   sync.WaitGroup{},
    54  		state:                stateInit,
    55  		lock:                 &sync.RWMutex{},
    56  		network:              network,
    57  		address:              address,
    58  		PerWriteTimeout:      setting.PerWriteTimeout,
    59  		PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
    60  	}
    61  
    62  	srv.BeforeBegin = func(network, addr string) {
    63  		log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
    64  	}
    65  
    66  	return srv
    67  }
    68  
    69  // ListenAndServe listens on the provided network address and then calls Serve
    70  // to handle requests on incoming connections.
    71  func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
    72  	go srv.awaitShutdown()
    73  
    74  	listener, err := GetListener(srv.network, srv.address)
    75  	if err != nil {
    76  		log.Error("Unable to GetListener: %v", err)
    77  		return err
    78  	}
    79  
    80  	// we need to wrap the listener to take account of our lifecycle
    81  	listener = newWrappedListener(listener, srv)
    82  
    83  	// Now we need to take account of ProxyProtocol settings...
    84  	if useProxyProtocol {
    85  		listener = &proxyprotocol.Listener{
    86  			Listener:           listener,
    87  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
    88  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
    89  		}
    90  	}
    91  	srv.listener = listener
    92  
    93  	srv.BeforeBegin(srv.network, srv.address)
    94  
    95  	return srv.Serve(serve)
    96  }
    97  
    98  // ListenAndServeTLSConfig listens on the provided network address and then calls
    99  // Serve to handle requests on incoming TLS connections.
   100  func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
   101  	go srv.awaitShutdown()
   102  
   103  	if tlsConfig.MinVersion == 0 {
   104  		tlsConfig.MinVersion = tls.VersionTLS12
   105  	}
   106  
   107  	listener, err := GetListener(srv.network, srv.address)
   108  	if err != nil {
   109  		log.Error("Unable to get Listener: %v", err)
   110  		return err
   111  	}
   112  
   113  	// we need to wrap the listener to take account of our lifecycle
   114  	listener = newWrappedListener(listener, srv)
   115  
   116  	// Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us
   117  	if useProxyProtocol && !proxyProtocolTLSBridging {
   118  		listener = &proxyprotocol.Listener{
   119  			Listener:           listener,
   120  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
   121  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
   122  		}
   123  	}
   124  
   125  	// Now handle the tls protocol
   126  	listener = tls.NewListener(listener, tlsConfig)
   127  
   128  	// Now if we're bridging then we need the proxy to tell us who we're bridging for...
   129  	if useProxyProtocol && proxyProtocolTLSBridging {
   130  		listener = &proxyprotocol.Listener{
   131  			Listener:           listener,
   132  			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
   133  			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
   134  		}
   135  	}
   136  
   137  	srv.listener = listener
   138  	srv.BeforeBegin(srv.network, srv.address)
   139  
   140  	return srv.Serve(serve)
   141  }
   142  
   143  // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
   144  // service goroutine for each. The service goroutines read requests and then call
   145  // handler to reply to them. Handler is typically nil, in which case the
   146  // DefaultServeMux is used.
   147  //
   148  // In addition to the standard Serve behaviour each connection is added to a
   149  // sync.Waitgroup so that all outstanding connections can be served before shutting
   150  // down the server.
   151  func (srv *Server) Serve(serve ServeFunction) error {
   152  	defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
   153  	srv.setState(stateRunning)
   154  	GetManager().RegisterServer()
   155  	err := serve(srv.listener)
   156  	log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
   157  	srv.wg.Wait()
   158  	srv.setState(stateTerminate)
   159  	GetManager().ServerDone()
   160  	// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
   161  	if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
   162  		return nil
   163  	}
   164  	return err
   165  }
   166  
   167  func (srv *Server) getState() state {
   168  	srv.lock.RLock()
   169  	defer srv.lock.RUnlock()
   170  
   171  	return srv.state
   172  }
   173  
   174  func (srv *Server) setState(st state) {
   175  	srv.lock.Lock()
   176  	defer srv.lock.Unlock()
   177  
   178  	srv.state = st
   179  }
   180  
   181  type filer interface {
   182  	File() (*os.File, error)
   183  }
   184  
   185  type wrappedListener struct {
   186  	net.Listener
   187  	stopped bool
   188  	server  *Server
   189  }
   190  
   191  func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
   192  	return &wrappedListener{
   193  		Listener: l,
   194  		server:   srv,
   195  	}
   196  }
   197  
   198  func (wl *wrappedListener) Accept() (net.Conn, error) {
   199  	var c net.Conn
   200  	// Set keepalive on TCPListeners connections.
   201  	if tcl, ok := wl.Listener.(*net.TCPListener); ok {
   202  		tc, err := tcl.AcceptTCP()
   203  		if err != nil {
   204  			return nil, err
   205  		}
   206  		_ = tc.SetKeepAlive(true)                  // see http.tcpKeepAliveListener
   207  		_ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
   208  		c = tc
   209  	} else {
   210  		var err error
   211  		c, err = wl.Listener.Accept()
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  	}
   216  
   217  	closed := int32(0)
   218  
   219  	c = &wrappedConn{
   220  		Conn:                 c,
   221  		server:               wl.server,
   222  		closed:               &closed,
   223  		perWriteTimeout:      wl.server.PerWriteTimeout,
   224  		perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
   225  	}
   226  
   227  	wl.server.wg.Add(1)
   228  	return c, nil
   229  }
   230  
   231  func (wl *wrappedListener) Close() error {
   232  	if wl.stopped {
   233  		return syscall.EINVAL
   234  	}
   235  
   236  	wl.stopped = true
   237  	return wl.Listener.Close()
   238  }
   239  
   240  func (wl *wrappedListener) File() (*os.File, error) {
   241  	// returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
   242  	return wl.Listener.(filer).File()
   243  }
   244  
   245  type wrappedConn struct {
   246  	net.Conn
   247  	server               *Server
   248  	closed               *int32
   249  	deadline             time.Time
   250  	perWriteTimeout      time.Duration
   251  	perWritePerKbTimeout time.Duration
   252  }
   253  
   254  func (w *wrappedConn) Write(p []byte) (n int, err error) {
   255  	if w.perWriteTimeout > 0 {
   256  		minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout
   257  		minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout)
   258  
   259  		w.deadline = w.deadline.Add(minTimeout)
   260  		if minDeadline.After(w.deadline) {
   261  			w.deadline = minDeadline
   262  		}
   263  		_ = w.Conn.SetWriteDeadline(w.deadline)
   264  	}
   265  	return w.Conn.Write(p)
   266  }
   267  
   268  func (w *wrappedConn) Close() error {
   269  	if atomic.CompareAndSwapInt32(w.closed, 0, 1) {
   270  		defer func() {
   271  			if err := recover(); err != nil {
   272  				select {
   273  				case <-GetManager().IsHammer():
   274  					// Likely deadlocked request released at hammertime
   275  					log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err)
   276  				default:
   277  					log.Error("Panic during connection close! %v", err)
   278  				}
   279  			}
   280  		}()
   281  		w.server.wg.Done()
   282  	}
   283  	return w.Conn.Close()
   284  }