github.com/astaxie/beego@v1.12.3/grace/server.go (about) 1 package grace 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "fmt" 8 "io/ioutil" 9 "log" 10 "net" 11 "net/http" 12 "os" 13 "os/exec" 14 "os/signal" 15 "strings" 16 "syscall" 17 "time" 18 ) 19 20 // Server embedded http.Server 21 type Server struct { 22 *http.Server 23 ln net.Listener 24 SignalHooks map[int]map[os.Signal][]func() 25 sigChan chan os.Signal 26 isChild bool 27 state uint8 28 Network string 29 terminalChan chan error 30 } 31 32 // Serve accepts incoming connections on the Listener l, 33 // creating a new service goroutine for each. 34 // The service goroutines read requests and then call srv.Handler to reply to them. 35 func (srv *Server) Serve() (err error) { 36 srv.state = StateRunning 37 defer func() { srv.state = StateTerminate }() 38 39 // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS 40 // immediately return ErrServerClosed. Make sure the program doesn't exit 41 // and waits instead for Shutdown to return. 42 if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { 43 log.Println(syscall.Getpid(), "Server.Serve() error:", err) 44 return err 45 } 46 47 log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") 48 // wait for Shutdown to return 49 if shutdownErr := <-srv.terminalChan; shutdownErr != nil { 50 return shutdownErr 51 } 52 return 53 } 54 55 // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve 56 // to handle requests on incoming connections. If srv.Addr is blank, ":http" is 57 // used. 58 func (srv *Server) ListenAndServe() (err error) { 59 addr := srv.Addr 60 if addr == "" { 61 addr = ":http" 62 } 63 64 go srv.handleSignals() 65 66 srv.ln, err = srv.getListener(addr) 67 if err != nil { 68 log.Println(err) 69 return err 70 } 71 72 if srv.isChild { 73 process, err := os.FindProcess(os.Getppid()) 74 if err != nil { 75 log.Println(err) 76 return err 77 } 78 err = process.Signal(syscall.SIGTERM) 79 if err != nil { 80 return err 81 } 82 } 83 84 log.Println(os.Getpid(), srv.Addr) 85 return srv.Serve() 86 } 87 88 // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls 89 // Serve to handle requests on incoming TLS connections. 90 // 91 // Filenames containing a certificate and matching private key for the server must 92 // be provided. If the certificate is signed by a certificate authority, the 93 // certFile should be the concatenation of the server's certificate followed by the 94 // CA's certificate. 95 // 96 // If srv.Addr is blank, ":https" is used. 97 func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { 98 addr := srv.Addr 99 if addr == "" { 100 addr = ":https" 101 } 102 103 if srv.TLSConfig == nil { 104 srv.TLSConfig = &tls.Config{} 105 } 106 if srv.TLSConfig.NextProtos == nil { 107 srv.TLSConfig.NextProtos = []string{"http/1.1"} 108 } 109 110 srv.TLSConfig.Certificates = make([]tls.Certificate, 1) 111 srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) 112 if err != nil { 113 return 114 } 115 116 go srv.handleSignals() 117 118 ln, err := srv.getListener(addr) 119 if err != nil { 120 log.Println(err) 121 return err 122 } 123 srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) 124 125 if srv.isChild { 126 process, err := os.FindProcess(os.Getppid()) 127 if err != nil { 128 log.Println(err) 129 return err 130 } 131 err = process.Signal(syscall.SIGTERM) 132 if err != nil { 133 return err 134 } 135 } 136 137 log.Println(os.Getpid(), srv.Addr) 138 return srv.Serve() 139 } 140 141 // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls 142 // Serve to handle requests on incoming mutual TLS connections. 143 func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { 144 addr := srv.Addr 145 if addr == "" { 146 addr = ":https" 147 } 148 149 if srv.TLSConfig == nil { 150 srv.TLSConfig = &tls.Config{} 151 } 152 if srv.TLSConfig.NextProtos == nil { 153 srv.TLSConfig.NextProtos = []string{"http/1.1"} 154 } 155 156 srv.TLSConfig.Certificates = make([]tls.Certificate, 1) 157 srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) 158 if err != nil { 159 return 160 } 161 srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert 162 pool := x509.NewCertPool() 163 data, err := ioutil.ReadFile(trustFile) 164 if err != nil { 165 log.Println(err) 166 return err 167 } 168 pool.AppendCertsFromPEM(data) 169 srv.TLSConfig.ClientCAs = pool 170 log.Println("Mutual HTTPS") 171 go srv.handleSignals() 172 173 ln, err := srv.getListener(addr) 174 if err != nil { 175 log.Println(err) 176 return err 177 } 178 srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) 179 180 if srv.isChild { 181 process, err := os.FindProcess(os.Getppid()) 182 if err != nil { 183 log.Println(err) 184 return err 185 } 186 err = process.Signal(syscall.SIGTERM) 187 if err != nil { 188 return err 189 } 190 } 191 192 log.Println(os.Getpid(), srv.Addr) 193 return srv.Serve() 194 } 195 196 // getListener either opens a new socket to listen on, or takes the acceptor socket 197 // it got passed when restarted. 198 func (srv *Server) getListener(laddr string) (l net.Listener, err error) { 199 if srv.isChild { 200 var ptrOffset uint 201 if len(socketPtrOffsetMap) > 0 { 202 ptrOffset = socketPtrOffsetMap[laddr] 203 log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) 204 } 205 206 f := os.NewFile(uintptr(3+ptrOffset), "") 207 l, err = net.FileListener(f) 208 if err != nil { 209 err = fmt.Errorf("net.FileListener error: %v", err) 210 return 211 } 212 } else { 213 l, err = net.Listen(srv.Network, laddr) 214 if err != nil { 215 err = fmt.Errorf("net.Listen error: %v", err) 216 return 217 } 218 } 219 return 220 } 221 222 type tcpKeepAliveListener struct { 223 *net.TCPListener 224 } 225 226 func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { 227 tc, err := ln.AcceptTCP() 228 if err != nil { 229 return 230 } 231 tc.SetKeepAlive(true) 232 tc.SetKeepAlivePeriod(3 * time.Minute) 233 return tc, nil 234 } 235 236 // handleSignals listens for os Signals and calls any hooked in function that the 237 // user had registered with the signal. 238 func (srv *Server) handleSignals() { 239 var sig os.Signal 240 241 signal.Notify( 242 srv.sigChan, 243 hookableSignals..., 244 ) 245 246 pid := syscall.Getpid() 247 for { 248 sig = <-srv.sigChan 249 srv.signalHooks(PreSignal, sig) 250 switch sig { 251 case syscall.SIGHUP: 252 log.Println(pid, "Received SIGHUP. forking.") 253 err := srv.fork() 254 if err != nil { 255 log.Println("Fork err:", err) 256 } 257 case syscall.SIGINT: 258 log.Println(pid, "Received SIGINT.") 259 srv.shutdown() 260 case syscall.SIGTERM: 261 log.Println(pid, "Received SIGTERM.") 262 srv.shutdown() 263 default: 264 log.Printf("Received %v: nothing i care about...\n", sig) 265 } 266 srv.signalHooks(PostSignal, sig) 267 } 268 } 269 270 func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { 271 if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { 272 return 273 } 274 for _, f := range srv.SignalHooks[ppFlag][sig] { 275 f() 276 } 277 } 278 279 // shutdown closes the listener so that no new connections are accepted. it also 280 // starts a goroutine that will serverTimeout (stop all running requests) the server 281 // after DefaultTimeout. 282 func (srv *Server) shutdown() { 283 if srv.state != StateRunning { 284 return 285 } 286 287 srv.state = StateShuttingDown 288 log.Println(syscall.Getpid(), "Waiting for connections to finish...") 289 ctx := context.Background() 290 if DefaultTimeout >= 0 { 291 var cancel context.CancelFunc 292 ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) 293 defer cancel() 294 } 295 srv.terminalChan <- srv.Server.Shutdown(ctx) 296 } 297 298 func (srv *Server) fork() (err error) { 299 regLock.Lock() 300 defer regLock.Unlock() 301 if runningServersForked { 302 return 303 } 304 runningServersForked = true 305 306 var files = make([]*os.File, len(runningServers)) 307 var orderArgs = make([]string, len(runningServers)) 308 for _, srvPtr := range runningServers { 309 f, _ := srvPtr.ln.(*net.TCPListener).File() 310 files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f 311 orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr 312 } 313 314 log.Println(files) 315 path := os.Args[0] 316 var args []string 317 if len(os.Args) > 1 { 318 for _, arg := range os.Args[1:] { 319 if arg == "-graceful" { 320 break 321 } 322 args = append(args, arg) 323 } 324 } 325 args = append(args, "-graceful") 326 if len(runningServers) > 1 { 327 args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) 328 log.Println(args) 329 } 330 cmd := exec.Command(path, args...) 331 cmd.Stdout = os.Stdout 332 cmd.Stderr = os.Stderr 333 cmd.ExtraFiles = files 334 err = cmd.Start() 335 if err != nil { 336 log.Fatalf("Restart: Failed to launch, error: %v", err) 337 } 338 339 return 340 } 341 342 // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. 343 func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { 344 if ppFlag != PreSignal && ppFlag != PostSignal { 345 err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") 346 return 347 } 348 for _, s := range hookableSignals { 349 if s == sig { 350 srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) 351 return 352 } 353 } 354 err = fmt.Errorf("Signal '%v' is not supported", sig) 355 return 356 }