github.com/wfusion/gofusion@v1.1.14/http/gracefully/endless.go (about)

     1  // fork from github.com/fvbock/endless@v0.0.0-20170109170031-447134032cb6
     2  // modified:
     3  // 1. support windows signals
     4  // 2. log content
     5  // 3. close by http.Server.Shutdown() rather than listener.Close()
     6  // 4. make sure Serve() exit after Shutdown() triggered by signals
     7  // 5. implement all public methods of *net.TCPConn
     8  
     9  package gracefully
    10  
    11  import (
    12  	"context"
    13  	"crypto/tls"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"log"
    18  	"net"
    19  	"net/http"
    20  	"os"
    21  	"os/exec"
    22  	"runtime"
    23  	"strings"
    24  	"sync"
    25  	"syscall"
    26  	"time"
    27  
    28  	"github.com/wfusion/gofusion/common/utils"
    29  	"github.com/wfusion/gofusion/routine"
    30  )
    31  
    32  const (
    33  	PreSignal = iota
    34  	PostSignal
    35  
    36  	StateInit
    37  	StateRunning
    38  	StateShuttingDown
    39  	StateTerminate
    40  )
    41  
    42  var (
    43  	DefaultReadTimeOut    time.Duration
    44  	DefaultWriteTimeOut   time.Duration
    45  	DefaultMaxHeaderBytes int
    46  	DefaultHammerTime     time.Duration
    47  
    48  	runningServerReg     sync.RWMutex
    49  	runningServers       map[string]*endlessServer
    50  	runningServersOrder  []string
    51  	socketPtrOffsetMap   map[string]uint
    52  	runningServersForked bool
    53  
    54  	isChild     bool
    55  	socketOrder string
    56  )
    57  
    58  func init() {
    59  	runningServerReg = sync.RWMutex{}
    60  	runningServers = make(map[string]*endlessServer)
    61  	runningServersOrder = []string{}
    62  	socketPtrOffsetMap = make(map[string]uint)
    63  
    64  	DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
    65  
    66  	// after a restart the parent will finish ongoing requests before
    67  	// shutting down. set to a negative value to disable
    68  	DefaultHammerTime = 60 * time.Second
    69  }
    70  
    71  type endlessServer struct {
    72  	*http.Server
    73  	SignalHooks map[int]map[os.Signal][]func()
    74  	BeforeBegin func(addr string)
    75  	AppName     string
    76  
    77  	endlessListener  net.Listener
    78  	tlsInnerListener *endlessListener
    79  	close            chan struct{}
    80  	wg               *sync.WaitGroup
    81  	sigChan          chan os.Signal
    82  	isChild          bool
    83  	state            uint8
    84  	lock             *sync.RWMutex
    85  }
    86  
    87  // NewServer returns an initialized endlessServer Object. Calling Serve on it will
    88  // actually "start" the server.
    89  func NewServer(appName string, handler http.Handler, addr string, nextProtos []string) (srv *endlessServer) {
    90  	runningServerReg.Lock()
    91  	defer runningServerReg.Unlock()
    92  
    93  	socketOrder = os.Getenv("ENDLESS_SOCKET_ORDER")
    94  	isChild = os.Getenv("ENDLESS_CONTINUE") != ""
    95  
    96  	if len(socketOrder) > 0 {
    97  		for i, addr := range strings.Split(socketOrder, ",") {
    98  			socketPtrOffsetMap[addr] = uint(i)
    99  		}
   100  	} else {
   101  		socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
   102  	}
   103  
   104  	srv = &endlessServer{
   105  		AppName: appName,
   106  		Server: &http.Server{
   107  			Addr:           addr,
   108  			ReadTimeout:    DefaultReadTimeOut,
   109  			WriteTimeout:   DefaultWriteTimeOut,
   110  			MaxHeaderBytes: DefaultMaxHeaderBytes,
   111  			Handler:        handler,
   112  			TLSConfig:      &tls.Config{NextProtos: nextProtos},
   113  		},
   114  		close:       make(chan struct{}),
   115  		wg:          new(sync.WaitGroup),
   116  		sigChan:     make(chan os.Signal),
   117  		isChild:     isChild,
   118  		SignalHooks: newSignalHookFunc(),
   119  		state:       StateInit,
   120  		lock:        new(sync.RWMutex),
   121  	}
   122  
   123  	runningServersOrder = append(runningServersOrder, addr)
   124  	runningServers[addr] = srv
   125  
   126  	return
   127  }
   128  
   129  // ListenAndServe listens on the TCP network address addr and then calls Serve
   130  // with handler to handle requests on incoming connections. Handler is typically
   131  // nil, in which case the DefaultServeMux is used.
   132  func ListenAndServe(appName string, handler http.Handler, addr string, nextProtos []string) error {
   133  	server := NewServer(appName, handler, addr, nextProtos)
   134  	return server.ListenAndServe()
   135  }
   136  
   137  // ListenAndServeTLS acts identically to ListenAndServe, except that it expects
   138  // HTTPS connections. Additionally, files containing a certificate and matching
   139  // private key for the server must be provided. If the certificate is signed by a
   140  // certificate authority, the certFile should be the concatenation of the server's
   141  // certificate followed by the CA's certificate.
   142  func ListenAndServeTLS(appName string, handler http.Handler, addr, certFile, keyFile string,
   143  	nextProtos []string) error {
   144  	server := NewServer(appName, handler, addr, nextProtos)
   145  	return server.ListenAndServeTLS(certFile, keyFile)
   146  }
   147  
   148  // Serve accepts incoming HTTP connections on the listener l, creating a new
   149  // service goroutine for each. The service goroutines read requests and then call
   150  // handler to reply to them. Handler is typically nil, in which case the
   151  // DefaultServeMux is used.
   152  //
   153  // In addition to the stl Serve behaviour each connection is added to a
   154  // sync.WaitGroup so that all outstanding connections can be served before shutting
   155  // down the server.
   156  func (e *endlessServer) Serve() (err error) {
   157  	defer log.Println(syscall.Getpid(), "[Common] endless exited.")
   158  
   159  	e.setState(StateRunning)
   160  	log.Println(syscall.Getpid(), "[Common] endless listening", e.endlessListener.Addr())
   161  
   162  	// ignore server closed error because it happened when we call Server.Shutdown or Server.Close
   163  	if err = e.Server.Serve(e.endlessListener); err != nil {
   164  		// http: Server closed
   165  		// use of closed network connection
   166  		if errors.Is(err, http.ErrServerClosed) || isClosedConnError(err) {
   167  			err = nil
   168  		}
   169  	}
   170  	log.Println(syscall.Getpid(), "[Common] endless waiting for connections to finish...")
   171  	e.wg.Wait()
   172  	e.setState(StateTerminate)
   173  
   174  	<-e.close
   175  
   176  	return
   177  }
   178  
   179  // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
   180  // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
   181  // used.
   182  func (e *endlessServer) ListenAndServe() (err error) {
   183  	addr := e.Addr
   184  	if addr == "" {
   185  		addr = ":http"
   186  	}
   187  
   188  	if err = setupHTTP2_Serve(e.Server); err != nil {
   189  		return
   190  	}
   191  
   192  	routine.Go(e.handleSignals, routine.AppName(e.AppName))
   193  
   194  	l, err := e.getListener(addr)
   195  	if err != nil {
   196  		log.Println(syscall.Getpid(), "[Common] endless", err)
   197  		return
   198  	}
   199  
   200  	e.endlessListener = newEndlessListener(l, e)
   201  	if e.isChild {
   202  		_ = syscallKill(syscall.Getppid())
   203  	}
   204  
   205  	if e.BeforeBegin != nil {
   206  		e.BeforeBegin(e.Addr)
   207  	}
   208  
   209  	return e.Serve()
   210  }
   211  
   212  // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
   213  // Serve to handle requests on incoming TLS connections.
   214  //
   215  // Filenames containing a certificate and matching private key for the server must
   216  // be provided. If the certificate is signed by a certificate authority, the
   217  // certFile should be the concatenation of the server's certificate followed by the
   218  // CA's certificate.
   219  //
   220  // If srv.Addr is blank, ":https" is used.
   221  func (e *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
   222  	addr := e.Addr
   223  	if addr == "" {
   224  		addr = ":https"
   225  	}
   226  
   227  	// Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
   228  	// before we clone it and create the TLS Listener.
   229  	if err = setupHTTP2_ServeTLS(e.Server); err != nil {
   230  		return
   231  	}
   232  
   233  	config := new(tls.Config)
   234  	if e.Server.TLSConfig != nil {
   235  		*config = *e.Server.TLSConfig.Clone()
   236  	}
   237  	if !utils.NewSet(config.NextProtos...).Contains("http/1.1") {
   238  		config.NextProtos = append(config.NextProtos, "http/1.1")
   239  	}
   240  
   241  	configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil
   242  	if !configHasCert || certFile != "" || keyFile != "" {
   243  		config.Certificates = make([]tls.Certificate, 1)
   244  		config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
   245  		if err != nil {
   246  			return err
   247  		}
   248  	}
   249  
   250  	routine.Go(e.handleSignals, routine.AppName(e.AppName))
   251  
   252  	l, err := e.getListener(addr)
   253  	if err != nil {
   254  		log.Println(syscall.Getpid(), "[Common] endless error occur when get listener:", err)
   255  		return
   256  	}
   257  
   258  	e.tlsInnerListener = newEndlessListener(l, e)
   259  	e.endlessListener = tls.NewListener(e.tlsInnerListener, config)
   260  	if e.isChild {
   261  		_ = syscallKill(syscall.Getppid())
   262  	}
   263  
   264  	return e.Serve()
   265  }
   266  
   267  // Shutdown closes the listener so that none new connections are accepted. it also
   268  // starts a goroutine that will hammer (stop all running requests) the server
   269  // after DefaultHammerTime.
   270  func (e *endlessServer) Shutdown() {
   271  	// make sure server Shutdown & log printed before Serve() return
   272  	defer func() {
   273  		e.lock.Lock()
   274  		defer e.lock.Unlock()
   275  		if _, ok := utils.IsChannelClosed(e.close); ok {
   276  			return
   277  		}
   278  		if e.close != nil {
   279  			close(e.close)
   280  		}
   281  	}()
   282  
   283  	if e.getState() != StateRunning {
   284  		return
   285  	}
   286  
   287  	e.setState(StateShuttingDown)
   288  	if DefaultHammerTime >= 0 {
   289  		routine.Loop(e.hammerTime, routine.Args(DefaultHammerTime), routine.AppName(e.AppName))
   290  	}
   291  	// disable keep-alive on existing connections
   292  	e.Server.SetKeepAlivesEnabled(false)
   293  
   294  	// TODO: new context with timeout because system may forcefully kill the program
   295  	if err := e.Server.Shutdown(context.TODO()); err != nil {
   296  		log.Println(syscall.Getpid(), "[Common] endless close listener error:", err)
   297  	} else {
   298  		log.Println(syscall.Getpid(), "[Common] endless", e.endlessListener.Addr(), "listener closed.")
   299  	}
   300  }
   301  
   302  // RegisterSignalHook registers a function to be run PreSignal or PostSignal for
   303  // a given signal. PRE or POST in this case means before or after the signal
   304  // related code endless itself runs
   305  func (e *endlessServer) RegisterSignalHook(prePost int, sig os.Signal, f func()) (err error) {
   306  	if prePost != PreSignal && prePost != PostSignal {
   307  		err = fmt.Errorf("cannot use %v for prePost arg. Must be endless.PRE_SIGNAL or endless.POST_SIGNAL", sig)
   308  		return
   309  	}
   310  	for _, s := range hookableSignals {
   311  		if s == sig {
   312  			e.SignalHooks[prePost][sig] = append(e.SignalHooks[prePost][sig], f)
   313  			return
   314  		}
   315  	}
   316  	err = fmt.Errorf("signal %v is not supported", sig)
   317  	return
   318  }
   319  
   320  // getListener either opens a new socket to listen on, or takes the acceptor socket
   321  // it got passed when restarted.
   322  func (e *endlessServer) getListener(addr string) (l net.Listener, err error) {
   323  	if e.isChild {
   324  		ptrOffset := uint(0)
   325  		runningServerReg.RLock()
   326  		defer runningServerReg.RUnlock()
   327  		if len(socketPtrOffsetMap) > 0 {
   328  			ptrOffset = socketPtrOffsetMap[addr]
   329  			log.Println(syscall.Getpid(), "[Common] endless addr:", addr, "ptr offset:", socketPtrOffsetMap[addr])
   330  		}
   331  
   332  		f := os.NewFile(uintptr(3+ptrOffset), "")
   333  		l, err = net.FileListener(f)
   334  		if err != nil {
   335  			err = fmt.Errorf("net.FileListener error: %v", err)
   336  			return
   337  		}
   338  	} else {
   339  		l, err = net.Listen("tcp", addr)
   340  		if err != nil {
   341  			err = fmt.Errorf("net.Listen error: %v", err)
   342  			return
   343  		}
   344  	}
   345  	return
   346  }
   347  
   348  func (e *endlessServer) signalHooks(ppFlag int, sig os.Signal) {
   349  	if _, notSet := e.SignalHooks[ppFlag][sig]; !notSet {
   350  		return
   351  	}
   352  	for _, f := range e.SignalHooks[ppFlag][sig] {
   353  		f()
   354  	}
   355  }
   356  
   357  // hammerTime forces the server to shut down in a given timeout - whether it
   358  // finished outstanding requests or not. if Read/WriteTimeout are not set or the
   359  // max header size is very big a connection could hang...
   360  //
   361  // srv.Serve() will not return until all connections are served. this will
   362  // unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to
   363  // return.
   364  func (e *endlessServer) hammerTime(d time.Duration) {
   365  	defer func() {
   366  		// we are calling e.wg.Done() until it panics which means we called
   367  		// Done() when the counter was already at 0, and we're done.
   368  		// (and thus Serve() will return and the parent will exit)
   369  		if r := recover(); r != nil {
   370  			log.Println(syscall.Getpid(), "[Common] endless wait group at 0", r)
   371  		}
   372  	}()
   373  	if e.getState() != StateShuttingDown {
   374  		return
   375  	}
   376  	time.Sleep(d)
   377  	log.Println(syscall.Getpid(), "[Common] endless harmerTime() forcefully shutting down parent")
   378  	for {
   379  		if e.getState() == StateTerminate {
   380  			break
   381  		}
   382  		e.wg.Done()
   383  		runtime.Gosched()
   384  	}
   385  }
   386  
   387  func (e *endlessServer) fork() (err error) {
   388  	runningServerReg.Lock()
   389  	defer runningServerReg.Unlock()
   390  
   391  	// only one server instance should fork!
   392  	if runningServersForked {
   393  		return errors.New("another process already forked, ignoring this one")
   394  	}
   395  
   396  	runningServersForked = true
   397  
   398  	var files = make([]*os.File, len(runningServers))
   399  	var orderArgs = make([]string, len(runningServers))
   400  	// get the accessor socket fds for _all_ server instances
   401  	for _, srvPtr := range runningServers {
   402  		// introspect.PrintTypeDump(srvPtr.endlessListener)
   403  		switch srvPtr.endlessListener.(type) {
   404  		case *endlessListener:
   405  			// normal listener
   406  			files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.endlessListener.(*endlessListener).File()
   407  		default:
   408  			// tls listener
   409  			files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
   410  		}
   411  		orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
   412  	}
   413  
   414  	env := append(
   415  		os.Environ(),
   416  		"ENDLESS_CONTINUE=1",
   417  	)
   418  	if len(runningServers) > 1 {
   419  		env = append(env, fmt.Sprintf(`ENDLESS_SOCKET_ORDER=%s`, strings.Join(orderArgs, ",")))
   420  	}
   421  
   422  	path := os.Args[0]
   423  	var args []string
   424  	if len(os.Args) > 1 {
   425  		args = os.Args[1:]
   426  	}
   427  
   428  	cmd := exec.Command(path, args...)
   429  	cmd.Stdout = os.Stdout
   430  	cmd.Stderr = os.Stderr
   431  	cmd.ExtraFiles = files
   432  	cmd.Env = env
   433  
   434  	if err = cmd.Start(); err != nil {
   435  		log.Fatalf("%v [Common] endless restart: failed to launch, error: %v", syscall.Getpid(), err)
   436  	}
   437  
   438  	return
   439  }
   440  
   441  func (e *endlessServer) getState() uint8 {
   442  	e.lock.RLock()
   443  	defer e.lock.RUnlock()
   444  
   445  	return e.state
   446  }
   447  
   448  func (e *endlessServer) setState(st uint8) {
   449  	e.lock.Lock()
   450  	defer e.lock.Unlock()
   451  
   452  	e.state = st
   453  }
   454  
   455  type endlessListener struct {
   456  	net.Listener
   457  	stopped bool
   458  	server  *endlessServer
   459  }
   460  
   461  func newEndlessListener(l net.Listener, srv *endlessServer) (el *endlessListener) {
   462  	return &endlessListener{
   463  		Listener: l,
   464  		server:   srv,
   465  	}
   466  }
   467  
   468  func (e *endlessListener) Accept() (c net.Conn, err error) {
   469  	tc, err := e.Listener.(*net.TCPListener).AcceptTCP()
   470  	if err != nil {
   471  		return
   472  	}
   473  
   474  	// see net/http.tcpKeepAliveListener
   475  	_ = tc.SetKeepAlive(true)
   476  	// see net/http.tcpKeepAliveListener
   477  	_ = tc.SetKeepAlivePeriod(3 * time.Minute)
   478  
   479  	c = &endlessConn{
   480  		Conn:   tc,
   481  		server: e.server,
   482  	}
   483  
   484  	e.server.wg.Add(1)
   485  	return
   486  }
   487  
   488  func (e *endlessListener) File() *os.File {
   489  	// returns a dup(2) - FD_CLOEXEC flag *not* set
   490  	tl := e.Listener.(*net.TCPListener)
   491  	fl, _ := tl.File()
   492  	return fl
   493  }
   494  
   495  type endlessConn struct {
   496  	net.Conn
   497  	doneOnce sync.Once
   498  	server   *endlessServer
   499  }
   500  
   501  // Read reads data from the connection.
   502  // Read can be made to time out and return an error after a fixed
   503  // time limit; see SetDeadline and SetReadDeadline.
   504  func (e *endlessConn) Read(b []byte) (n int, err error) { return e.Conn.Read(b) }
   505  
   506  // Write writes data to the connection.
   507  // Write can be made to time out and return an error after a fixed
   508  // time limit; see SetDeadline and SetWriteDeadline.
   509  func (e *endlessConn) Write(b []byte) (n int, err error) { return e.Conn.Write(b) }
   510  
   511  // Close closes the connection.
   512  // Any blocked Read or Write operations will be unblocked and return errors.
   513  func (e *endlessConn) Close() (err error) {
   514  	defer e.doneOnce.Do(func() {
   515  		e.server.wg.Done()
   516  	})
   517  	return e.Conn.Close()
   518  }
   519  
   520  // LocalAddr returns the local network address, if known.
   521  func (e *endlessConn) LocalAddr() net.Addr { return e.Conn.LocalAddr() }
   522  
   523  // RemoteAddr returns the remote network address, if known.
   524  func (e *endlessConn) RemoteAddr() net.Addr { return e.Conn.RemoteAddr() }
   525  
   526  // SetDeadline sets the read and write deadlines associated
   527  // with the connection. It is equivalent to calling both
   528  // SetReadDeadline and SetWriteDeadline.
   529  //
   530  // A deadline is an absolute time after which I/O operations
   531  // fail instead of blocking. The deadline applies to all future
   532  // and pending I/O, not just the immediately following call to
   533  // Read or Write. After a deadline has been exceeded, the
   534  // connection can be refreshed by setting a deadline in the future.
   535  //
   536  // If the deadline is exceeded a call to Read or Write or to other
   537  // I/O methods will return an error that wraps os.ErrDeadlineExceeded.
   538  // This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
   539  // The error's Timeout method will return true, but note that there
   540  // are other possible errors for which the Timeout method will
   541  // return true even if the deadline has not been exceeded.
   542  //
   543  // An idle timeout can be implemented by repeatedly extending
   544  // the deadline after successful Read or Write calls.
   545  //
   546  // A zero value for t means I/O operations will not time out.
   547  func (e *endlessConn) SetDeadline(t time.Time) error { return e.Conn.SetDeadline(t) }
   548  
   549  // SetReadDeadline sets the deadline for future Read calls
   550  // and any currently-blocked Read call.
   551  // A zero value for t means Read will not time out.
   552  func (e *endlessConn) SetReadDeadline(t time.Time) error { return e.Conn.SetReadDeadline(t) }
   553  
   554  // SetWriteDeadline sets the deadline for future Write calls
   555  // and any currently-blocked Write call.
   556  // Even if write times out, it may return n > 0, indicating that
   557  // some of the data was successfully written.
   558  // A zero value for t means Write will not time out.
   559  func (e *endlessConn) SetWriteDeadline(t time.Time) error { return e.Conn.SetWriteDeadline(t) }
   560  
   561  // SyscallConn returns a raw network connection.
   562  // This implements the syscall.Conn interface.
   563  func (e *endlessConn) SyscallConn() (syscall.RawConn, error) {
   564  	return e.Conn.(*net.TCPConn).SyscallConn()
   565  }
   566  
   567  // ReadFrom implements the io.ReaderFrom ReadFrom method.
   568  func (e *endlessConn) ReadFrom(r io.Reader) (int64, error) {
   569  	return e.Conn.(*net.TCPConn).ReadFrom(r)
   570  }
   571  
   572  // SetLinger sets the behavior of Close on a connection which still
   573  // has data waiting to be sent or to be acknowledged.
   574  //
   575  // If sec < 0 (the default), the operating system finishes sending the
   576  // data in the background.
   577  //
   578  // If sec == 0, the operating system discards any unsent or
   579  // unacknowledged data.
   580  //
   581  // If sec > 0, the data is sent in the background as with sec < 0. On
   582  // some operating systems after sec seconds have elapsed any remaining
   583  // unsent data may be discarded.
   584  func (e *endlessConn) SetLinger(sec int) error {
   585  	return e.Conn.(*net.TCPConn).SetLinger(sec)
   586  }
   587  
   588  // SetKeepAlive sets whether the operating system should send
   589  // keep-alive messages on the connection.
   590  func (e *endlessConn) SetKeepAlive(keepalive bool) error {
   591  	return e.Conn.(*net.TCPConn).SetKeepAlive(keepalive)
   592  }
   593  
   594  // SetKeepAlivePeriod sets period between keep-alives.
   595  func (e *endlessConn) SetKeepAlivePeriod(d time.Duration) error {
   596  	return e.Conn.(*net.TCPConn).SetKeepAlivePeriod(d)
   597  }
   598  
   599  // SetNoDelay controls whether the operating system should delay
   600  // packet transmission in hopes of sending fewer packets (Nagle's
   601  // algorithm).  The default is true (no delay), meaning that data is
   602  // sent as soon as possible after a Write.
   603  func (e *endlessConn) SetNoDelay(noDelay bool) error {
   604  	return e.Conn.(*net.TCPConn).SetNoDelay(noDelay)
   605  }