github.com/gitbundle/modules@v0.0.0-20231025071548-85b91c5c3b01/graceful/server.go (about)

     1  // Copyright 2023 The GitBundle Inc. All rights reserved.
     2  // Copyright 2017 The Gitea Authors. All rights reserved.
     3  // Use of this source code is governed by a MIT-style
     4  // license that can be found in the LICENSE file.
     5  
// This code is heavily inspired by the endless Go library.
     7  
     8  package graceful
     9  
    10  import (
    11  	"crypto/tls"
    12  	"net"
    13  	"os"
    14  	"strings"
    15  	"sync"
    16  	"sync/atomic"
    17  	"syscall"
    18  	"time"
    19  
    20  	"github.com/gitbundle/modules/log"
    21  	"github.com/gitbundle/modules/setting"
    22  )
    23  
    24  var (
    25  	// DefaultReadTimeOut default read timeout
    26  	DefaultReadTimeOut time.Duration
    27  	// DefaultWriteTimeOut default write timeout
    28  	DefaultWriteTimeOut time.Duration
    29  	// DefaultMaxHeaderBytes default max header bytes
    30  	DefaultMaxHeaderBytes int
    31  	// PerWriteWriteTimeout timeout for writes
    32  	PerWriteWriteTimeout = 30 * time.Second
    33  	// PerWriteWriteTimeoutKbTime is a timeout taking account of how much there is to be written
    34  	PerWriteWriteTimeoutKbTime = 10 * time.Second
    35  )
    36  
    37  func init() {
    38  	DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
    39  }
    40  
    41  // ServeFunction represents a listen.Accept loop
    42  type ServeFunction = func(net.Listener) error
    43  
    44  // Server represents our graceful server
    45  type Server struct {
    46  	network              string
    47  	address              string
    48  	listener             net.Listener
    49  	wg                   sync.WaitGroup
    50  	state                state
    51  	lock                 *sync.RWMutex
    52  	BeforeBegin          func(network, address string)
    53  	OnShutdown           func()
    54  	PerWriteTimeout      time.Duration
    55  	PerWritePerKbTimeout time.Duration
    56  }
    57  
    58  // NewServer creates a server on network at provided address
    59  func NewServer(network, address, name string) *Server {
    60  	if GetManager().IsChild() {
    61  		log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    62  	} else {
    63  		log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
    64  	}
    65  	srv := &Server{
    66  		wg:                   sync.WaitGroup{},
    67  		state:                stateInit,
    68  		lock:                 &sync.RWMutex{},
    69  		network:              network,
    70  		address:              address,
    71  		PerWriteTimeout:      setting.PerWriteTimeout,
    72  		PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
    73  	}
    74  
    75  	srv.BeforeBegin = func(network, addr string) {
    76  		log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
    77  	}
    78  
    79  	return srv
    80  }
    81  
    82  // ListenAndServe listens on the provided network address and then calls Serve
    83  // to handle requests on incoming connections.
    84  func (srv *Server) ListenAndServe(serve ServeFunction) error {
    85  	go srv.awaitShutdown()
    86  
    87  	l, err := GetListener(srv.network, srv.address)
    88  	if err != nil {
    89  		log.Error("Unable to GetListener: %v", err)
    90  		return err
    91  	}
    92  
    93  	srv.listener = newWrappedListener(l, srv)
    94  
    95  	srv.BeforeBegin(srv.network, srv.address)
    96  
    97  	return srv.Serve(serve)
    98  }
    99  
   100  // ListenAndServeTLSConfig listens on the provided network address and then calls
   101  // Serve to handle requests on incoming TLS connections.
   102  func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error {
   103  	go srv.awaitShutdown()
   104  
   105  	if tlsConfig.MinVersion == 0 {
   106  		tlsConfig.MinVersion = tls.VersionTLS12
   107  	}
   108  
   109  	l, err := GetListener(srv.network, srv.address)
   110  	if err != nil {
   111  		log.Error("Unable to get Listener: %v", err)
   112  		return err
   113  	}
   114  
   115  	wl := newWrappedListener(l, srv)
   116  	srv.listener = tls.NewListener(wl, tlsConfig)
   117  
   118  	srv.BeforeBegin(srv.network, srv.address)
   119  
   120  	return srv.Serve(serve)
   121  }
   122  
   123  // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
   124  // service goroutine for each. The service goroutines read requests and then call
   125  // handler to reply to them. Handler is typically nil, in which case the
   126  // DefaultServeMux is used.
   127  //
   128  // In addition to the standard Serve behaviour each connection is added to a
   129  // sync.Waitgroup so that all outstanding connections can be served before shutting
   130  // down the server.
   131  func (srv *Server) Serve(serve ServeFunction) error {
   132  	defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
   133  	srv.setState(stateRunning)
   134  	GetManager().RegisterServer()
   135  	err := serve(srv.listener)
   136  	log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
   137  	srv.wg.Wait()
   138  	srv.setState(stateTerminate)
   139  	GetManager().ServerDone()
   140  	// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
   141  	if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
   142  		return nil
   143  	}
   144  	return err
   145  }
   146  
   147  func (srv *Server) getState() state {
   148  	srv.lock.RLock()
   149  	defer srv.lock.RUnlock()
   150  
   151  	return srv.state
   152  }
   153  
   154  func (srv *Server) setState(st state) {
   155  	srv.lock.Lock()
   156  	defer srv.lock.Unlock()
   157  
   158  	srv.state = st
   159  }
   160  
   161  type filer interface {
   162  	File() (*os.File, error)
   163  }
   164  
   165  type wrappedListener struct {
   166  	net.Listener
   167  	stopped bool
   168  	server  *Server
   169  }
   170  
   171  func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
   172  	return &wrappedListener{
   173  		Listener: l,
   174  		server:   srv,
   175  	}
   176  }
   177  
   178  func (wl *wrappedListener) Accept() (net.Conn, error) {
   179  	var c net.Conn
   180  	// Set keepalive on TCPListeners connections.
   181  	if tcl, ok := wl.Listener.(*net.TCPListener); ok {
   182  		tc, err := tcl.AcceptTCP()
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		_ = tc.SetKeepAlive(true)                  // see http.tcpKeepAliveListener
   187  		_ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
   188  		c = tc
   189  	} else {
   190  		var err error
   191  		c, err = wl.Listener.Accept()
   192  		if err != nil {
   193  			return nil, err
   194  		}
   195  	}
   196  
   197  	closed := int32(0)
   198  
   199  	c = &wrappedConn{
   200  		Conn:                 c,
   201  		server:               wl.server,
   202  		closed:               &closed,
   203  		perWriteTimeout:      wl.server.PerWriteTimeout,
   204  		perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
   205  	}
   206  
   207  	wl.server.wg.Add(1)
   208  	return c, nil
   209  }
   210  
   211  func (wl *wrappedListener) Close() error {
   212  	if wl.stopped {
   213  		return syscall.EINVAL
   214  	}
   215  
   216  	wl.stopped = true
   217  	return wl.Listener.Close()
   218  }
   219  
   220  func (wl *wrappedListener) File() (*os.File, error) {
   221  	// returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
   222  	return wl.Listener.(filer).File()
   223  }
   224  
   225  type wrappedConn struct {
   226  	net.Conn
   227  	server               *Server
   228  	closed               *int32
   229  	deadline             time.Time
   230  	perWriteTimeout      time.Duration
   231  	perWritePerKbTimeout time.Duration
   232  }
   233  
   234  func (w *wrappedConn) Write(p []byte) (n int, err error) {
   235  	if w.perWriteTimeout > 0 {
   236  		minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout
   237  		minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout)
   238  
   239  		w.deadline = w.deadline.Add(minTimeout)
   240  		if minDeadline.After(w.deadline) {
   241  			w.deadline = minDeadline
   242  		}
   243  		_ = w.Conn.SetWriteDeadline(w.deadline)
   244  	}
   245  	return w.Conn.Write(p)
   246  }
   247  
   248  func (w *wrappedConn) Close() error {
   249  	if atomic.CompareAndSwapInt32(w.closed, 0, 1) {
   250  		defer func() {
   251  			if err := recover(); err != nil {
   252  				select {
   253  				case <-GetManager().IsHammer():
   254  					// Likely deadlocked request released at hammertime
   255  					log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err)
   256  				default:
   257  					log.Error("Panic during connection close! %v", err)
   258  				}
   259  			}
   260  		}()
   261  		w.server.wg.Done()
   262  	}
   263  	return w.Conn.Close()
   264  }