github.com/astaxie/beego@v1.12.3/grace/server.go (about)

     1  package grace
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"os/exec"
    14  	"os/signal"
    15  	"strings"
    16  	"syscall"
    17  	"time"
    18  )
    19  
// Server wraps an http.Server with the bookkeeping needed to restart and shut
// down gracefully in response to OS signals.
type Server struct {
	// Embedded standard-library server; Serve and shutdown delegate to its
	// Serve and Shutdown methods.
	*http.Server
	// ln is the listener Serve accepts connections on. It is set by the
	// ListenAndServe* methods and handed to the child process by fork.
	ln           net.Listener
	// SignalHooks holds user callbacks, indexed first by the PreSignal or
	// PostSignal flag and then by signal. See RegisterSignalHook.
	SignalHooks  map[int]map[os.Signal][]func()
	// sigChan receives the signals that handleSignals subscribes to.
	sigChan      chan os.Signal
	// isChild is true when the listening socket is inherited from a parent
	// process instead of being opened afresh (see getListener).
	isChild      bool
	// state tracks the lifecycle: StateRunning while Serve is active,
	// StateShuttingDown once shutdown has begun, StateTerminate after Serve
	// returns.
	state        uint8
	// Network is the network name passed to net.Listen when a new socket is
	// opened.
	Network      string
	// terminalChan carries the result of http.Server.Shutdown from shutdown
	// to Serve.
	terminalChan chan error
}
    31  
    32  // Serve accepts incoming connections on the Listener l,
    33  // creating a new service goroutine for each.
    34  // The service goroutines read requests and then call srv.Handler to reply to them.
    35  func (srv *Server) Serve() (err error) {
    36  	srv.state = StateRunning
    37  	defer func() { srv.state = StateTerminate }()
    38  
    39  	// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
    40  	// immediately return ErrServerClosed. Make sure the program doesn't exit
    41  	// and waits instead for Shutdown to return.
    42  	if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
    43  		log.Println(syscall.Getpid(), "Server.Serve() error:", err)
    44  		return err
    45  	}
    46  
    47  	log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
    48  	// wait for Shutdown to return
    49  	if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
    50  		return shutdownErr
    51  	}
    52  	return
    53  }
    54  
    55  // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
    56  // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
    57  // used.
    58  func (srv *Server) ListenAndServe() (err error) {
    59  	addr := srv.Addr
    60  	if addr == "" {
    61  		addr = ":http"
    62  	}
    63  
    64  	go srv.handleSignals()
    65  
    66  	srv.ln, err = srv.getListener(addr)
    67  	if err != nil {
    68  		log.Println(err)
    69  		return err
    70  	}
    71  
    72  	if srv.isChild {
    73  		process, err := os.FindProcess(os.Getppid())
    74  		if err != nil {
    75  			log.Println(err)
    76  			return err
    77  		}
    78  		err = process.Signal(syscall.SIGTERM)
    79  		if err != nil {
    80  			return err
    81  		}
    82  	}
    83  
    84  	log.Println(os.Getpid(), srv.Addr)
    85  	return srv.Serve()
    86  }
    87  
    88  // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
    89  // Serve to handle requests on incoming TLS connections.
    90  //
    91  // Filenames containing a certificate and matching private key for the server must
    92  // be provided. If the certificate is signed by a certificate authority, the
    93  // certFile should be the concatenation of the server's certificate followed by the
    94  // CA's certificate.
    95  //
    96  // If srv.Addr is blank, ":https" is used.
    97  func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
    98  	addr := srv.Addr
    99  	if addr == "" {
   100  		addr = ":https"
   101  	}
   102  
   103  	if srv.TLSConfig == nil {
   104  		srv.TLSConfig = &tls.Config{}
   105  	}
   106  	if srv.TLSConfig.NextProtos == nil {
   107  		srv.TLSConfig.NextProtos = []string{"http/1.1"}
   108  	}
   109  
   110  	srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
   111  	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
   112  	if err != nil {
   113  		return
   114  	}
   115  
   116  	go srv.handleSignals()
   117  
   118  	ln, err := srv.getListener(addr)
   119  	if err != nil {
   120  		log.Println(err)
   121  		return err
   122  	}
   123  	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
   124  
   125  	if srv.isChild {
   126  		process, err := os.FindProcess(os.Getppid())
   127  		if err != nil {
   128  			log.Println(err)
   129  			return err
   130  		}
   131  		err = process.Signal(syscall.SIGTERM)
   132  		if err != nil {
   133  			return err
   134  		}
   135  	}
   136  
   137  	log.Println(os.Getpid(), srv.Addr)
   138  	return srv.Serve()
   139  }
   140  
   141  // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
   142  // Serve to handle requests on incoming mutual TLS connections.
   143  func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
   144  	addr := srv.Addr
   145  	if addr == "" {
   146  		addr = ":https"
   147  	}
   148  
   149  	if srv.TLSConfig == nil {
   150  		srv.TLSConfig = &tls.Config{}
   151  	}
   152  	if srv.TLSConfig.NextProtos == nil {
   153  		srv.TLSConfig.NextProtos = []string{"http/1.1"}
   154  	}
   155  
   156  	srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
   157  	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
   158  	if err != nil {
   159  		return
   160  	}
   161  	srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
   162  	pool := x509.NewCertPool()
   163  	data, err := ioutil.ReadFile(trustFile)
   164  	if err != nil {
   165  		log.Println(err)
   166  		return err
   167  	}
   168  	pool.AppendCertsFromPEM(data)
   169  	srv.TLSConfig.ClientCAs = pool
   170  	log.Println("Mutual HTTPS")
   171  	go srv.handleSignals()
   172  
   173  	ln, err := srv.getListener(addr)
   174  	if err != nil {
   175  		log.Println(err)
   176  		return err
   177  	}
   178  	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
   179  
   180  	if srv.isChild {
   181  		process, err := os.FindProcess(os.Getppid())
   182  		if err != nil {
   183  			log.Println(err)
   184  			return err
   185  		}
   186  		err = process.Signal(syscall.SIGTERM)
   187  		if err != nil {
   188  			return err
   189  		}
   190  	}
   191  
   192  	log.Println(os.Getpid(), srv.Addr)
   193  	return srv.Serve()
   194  }
   195  
   196  // getListener either opens a new socket to listen on, or takes the acceptor socket
   197  // it got passed when restarted.
   198  func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
   199  	if srv.isChild {
   200  		var ptrOffset uint
   201  		if len(socketPtrOffsetMap) > 0 {
   202  			ptrOffset = socketPtrOffsetMap[laddr]
   203  			log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
   204  		}
   205  
   206  		f := os.NewFile(uintptr(3+ptrOffset), "")
   207  		l, err = net.FileListener(f)
   208  		if err != nil {
   209  			err = fmt.Errorf("net.FileListener error: %v", err)
   210  			return
   211  		}
   212  	} else {
   213  		l, err = net.Listen(srv.Network, laddr)
   214  		if err != nil {
   215  			err = fmt.Errorf("net.Listen error: %v", err)
   216  			return
   217  		}
   218  	}
   219  	return
   220  }
   221  
   222  type tcpKeepAliveListener struct {
   223  	*net.TCPListener
   224  }
   225  
   226  func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
   227  	tc, err := ln.AcceptTCP()
   228  	if err != nil {
   229  		return
   230  	}
   231  	tc.SetKeepAlive(true)
   232  	tc.SetKeepAlivePeriod(3 * time.Minute)
   233  	return tc, nil
   234  }
   235  
   236  // handleSignals listens for os Signals and calls any hooked in function that the
   237  // user had registered with the signal.
   238  func (srv *Server) handleSignals() {
   239  	var sig os.Signal
   240  
   241  	signal.Notify(
   242  		srv.sigChan,
   243  		hookableSignals...,
   244  	)
   245  
   246  	pid := syscall.Getpid()
   247  	for {
   248  		sig = <-srv.sigChan
   249  		srv.signalHooks(PreSignal, sig)
   250  		switch sig {
   251  		case syscall.SIGHUP:
   252  			log.Println(pid, "Received SIGHUP. forking.")
   253  			err := srv.fork()
   254  			if err != nil {
   255  				log.Println("Fork err:", err)
   256  			}
   257  		case syscall.SIGINT:
   258  			log.Println(pid, "Received SIGINT.")
   259  			srv.shutdown()
   260  		case syscall.SIGTERM:
   261  			log.Println(pid, "Received SIGTERM.")
   262  			srv.shutdown()
   263  		default:
   264  			log.Printf("Received %v: nothing i care about...\n", sig)
   265  		}
   266  		srv.signalHooks(PostSignal, sig)
   267  	}
   268  }
   269  
   270  func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
   271  	if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
   272  		return
   273  	}
   274  	for _, f := range srv.SignalHooks[ppFlag][sig] {
   275  		f()
   276  	}
   277  }
   278  
   279  // shutdown closes the listener so that no new connections are accepted. it also
   280  // starts a goroutine that will serverTimeout (stop all running requests) the server
   281  // after DefaultTimeout.
   282  func (srv *Server) shutdown() {
   283  	if srv.state != StateRunning {
   284  		return
   285  	}
   286  
   287  	srv.state = StateShuttingDown
   288  	log.Println(syscall.Getpid(), "Waiting for connections to finish...")
   289  	ctx := context.Background()
   290  	if DefaultTimeout >= 0 {
   291  		var cancel context.CancelFunc
   292  		ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
   293  		defer cancel()
   294  	}
   295  	srv.terminalChan <- srv.Server.Shutdown(ctx)
   296  }
   297  
   298  func (srv *Server) fork() (err error) {
   299  	regLock.Lock()
   300  	defer regLock.Unlock()
   301  	if runningServersForked {
   302  		return
   303  	}
   304  	runningServersForked = true
   305  
   306  	var files = make([]*os.File, len(runningServers))
   307  	var orderArgs = make([]string, len(runningServers))
   308  	for _, srvPtr := range runningServers {
   309  		f, _ := srvPtr.ln.(*net.TCPListener).File()
   310  		files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
   311  		orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
   312  	}
   313  
   314  	log.Println(files)
   315  	path := os.Args[0]
   316  	var args []string
   317  	if len(os.Args) > 1 {
   318  		for _, arg := range os.Args[1:] {
   319  			if arg == "-graceful" {
   320  				break
   321  			}
   322  			args = append(args, arg)
   323  		}
   324  	}
   325  	args = append(args, "-graceful")
   326  	if len(runningServers) > 1 {
   327  		args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
   328  		log.Println(args)
   329  	}
   330  	cmd := exec.Command(path, args...)
   331  	cmd.Stdout = os.Stdout
   332  	cmd.Stderr = os.Stderr
   333  	cmd.ExtraFiles = files
   334  	err = cmd.Start()
   335  	if err != nil {
   336  		log.Fatalf("Restart: Failed to launch, error: %v", err)
   337  	}
   338  
   339  	return
   340  }
   341  
   342  // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
   343  func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
   344  	if ppFlag != PreSignal && ppFlag != PostSignal {
   345  		err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
   346  		return
   347  	}
   348  	for _, s := range hookableSignals {
   349  		if s == sig {
   350  			srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
   351  			return
   352  		}
   353  	}
   354  	err = fmt.Errorf("Signal '%v' is not supported", sig)
   355  	return
   356  }