code.gitea.io/gitea@v1.19.3/modules/graceful/server.go (about) 1 // Copyright 2019 The Gitea Authors. All rights reserved. 2 // SPDX-License-Identifier: MIT 3 4 // This code is highly inspired by endless go 5 6 package graceful 7 8 import ( 9 "crypto/tls" 10 "net" 11 "os" 12 "strings" 13 "sync" 14 "sync/atomic" 15 "syscall" 16 "time" 17 18 "code.gitea.io/gitea/modules/log" 19 "code.gitea.io/gitea/modules/proxyprotocol" 20 "code.gitea.io/gitea/modules/setting" 21 ) 22 23 var ( 24 // DefaultReadTimeOut default read timeout 25 DefaultReadTimeOut time.Duration 26 // DefaultWriteTimeOut default write timeout 27 DefaultWriteTimeOut time.Duration 28 // DefaultMaxHeaderBytes default max header bytes 29 DefaultMaxHeaderBytes int 30 // PerWriteWriteTimeout timeout for writes 31 PerWriteWriteTimeout = 30 * time.Second 32 // PerWriteWriteTimeoutKbTime is a timeout taking account of how much there is to be written 33 PerWriteWriteTimeoutKbTime = 10 * time.Second 34 ) 35 36 func init() { 37 DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB) 38 } 39 40 // ServeFunction represents a listen.Accept loop 41 type ServeFunction = func(net.Listener) error 42 43 // Server represents our graceful server 44 type Server struct { 45 network string 46 address string 47 listener net.Listener 48 wg sync.WaitGroup 49 state state 50 lock *sync.RWMutex 51 BeforeBegin func(network, address string) 52 OnShutdown func() 53 PerWriteTimeout time.Duration 54 PerWritePerKbTimeout time.Duration 55 } 56 57 // NewServer creates a server on network at provided address 58 func NewServer(network, address, name string) *Server { 59 if GetManager().IsChild() { 60 log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) 61 } else { 62 log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) 63 } 64 srv := &Server{ 65 wg: sync.WaitGroup{}, 66 state: stateInit, 67 lock: &sync.RWMutex{}, 68 network: network, 69 address: address, 70 PerWriteTimeout: setting.PerWriteTimeout, 71 PerWritePerKbTimeout: setting.PerWritePerKbTimeout, 72 } 73 74 srv.BeforeBegin = func(network, addr string) { 75 log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid()) 76 } 77 78 return srv 79 } 80 81 // ListenAndServe listens on the provided network address and then calls Serve 82 // to handle requests on incoming connections. 83 func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error { 84 go srv.awaitShutdown() 85 86 listener, err := GetListener(srv.network, srv.address) 87 if err != nil { 88 log.Error("Unable to GetListener: %v", err) 89 return err 90 } 91 92 // we need to wrap the listener to take account of our lifecycle 93 listener = newWrappedListener(listener, srv) 94 95 // Now we need to take account of ProxyProtocol settings... 96 if useProxyProtocol { 97 listener = &proxyprotocol.Listener{ 98 Listener: listener, 99 ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, 100 AcceptUnknown: setting.ProxyProtocolAcceptUnknown, 101 } 102 } 103 srv.listener = listener 104 105 srv.BeforeBegin(srv.network, srv.address) 106 107 return srv.Serve(serve) 108 } 109 110 // ListenAndServeTLSConfig listens on the provided network address and then calls 111 // Serve to handle requests on incoming TLS connections. 112 func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error { 113 go srv.awaitShutdown() 114 115 if tlsConfig.MinVersion == 0 { 116 tlsConfig.MinVersion = tls.VersionTLS12 117 } 118 119 listener, err := GetListener(srv.network, srv.address) 120 if err != nil { 121 log.Error("Unable to get Listener: %v", err) 122 return err 123 } 124 125 // we need to wrap the listener to take account of our lifecycle 126 listener = newWrappedListener(listener, srv) 127 128 // Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us 129 if useProxyProtocol && !proxyProtocolTLSBridging { 130 listener = &proxyprotocol.Listener{ 131 Listener: listener, 132 ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, 133 AcceptUnknown: setting.ProxyProtocolAcceptUnknown, 134 } 135 } 136 137 // Now handle the tls protocol 138 listener = tls.NewListener(listener, tlsConfig) 139 140 // Now if we're bridging then we need the proxy to tell us who we're bridging for... 141 if useProxyProtocol && proxyProtocolTLSBridging { 142 listener = &proxyprotocol.Listener{ 143 Listener: listener, 144 ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, 145 AcceptUnknown: setting.ProxyProtocolAcceptUnknown, 146 } 147 } 148 149 srv.listener = listener 150 srv.BeforeBegin(srv.network, srv.address) 151 152 return srv.Serve(serve) 153 } 154 155 // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new 156 // service goroutine for each. The service goroutines read requests and then call 157 // handler to reply to them. Handler is typically nil, in which case the 158 // DefaultServeMux is used. 159 // 160 // In addition to the standard Serve behaviour each connection is added to a 161 // sync.Waitgroup so that all outstanding connections can be served before shutting 162 // down the server. 163 func (srv *Server) Serve(serve ServeFunction) error { 164 defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid()) 165 srv.setState(stateRunning) 166 GetManager().RegisterServer() 167 err := serve(srv.listener) 168 log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid()) 169 srv.wg.Wait() 170 srv.setState(stateTerminate) 171 GetManager().ServerDone() 172 // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil 173 if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") { 174 return nil 175 } 176 return err 177 } 178 179 func (srv *Server) getState() state { 180 srv.lock.RLock() 181 defer srv.lock.RUnlock() 182 183 return srv.state 184 } 185 186 func (srv *Server) setState(st state) { 187 srv.lock.Lock() 188 defer srv.lock.Unlock() 189 190 srv.state = st 191 } 192 193 type filer interface { 194 File() (*os.File, error) 195 } 196 197 type wrappedListener struct { 198 net.Listener 199 stopped bool 200 server *Server 201 } 202 203 func newWrappedListener(l net.Listener, srv *Server) *wrappedListener { 204 return &wrappedListener{ 205 Listener: l, 206 server: srv, 207 } 208 } 209 210 func (wl *wrappedListener) Accept() (net.Conn, error) { 211 var c net.Conn 212 // Set keepalive on TCPListeners connections. 213 if tcl, ok := wl.Listener.(*net.TCPListener); ok { 214 tc, err := tcl.AcceptTCP() 215 if err != nil { 216 return nil, err 217 } 218 _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener 219 _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener 220 c = tc 221 } else { 222 var err error 223 c, err = wl.Listener.Accept() 224 if err != nil { 225 return nil, err 226 } 227 } 228 229 closed := int32(0) 230 231 c = &wrappedConn{ 232 Conn: c, 233 server: wl.server, 234 closed: &closed, 235 perWriteTimeout: wl.server.PerWriteTimeout, 236 perWritePerKbTimeout: wl.server.PerWritePerKbTimeout, 237 } 238 239 wl.server.wg.Add(1) 240 return c, nil 241 } 242 243 func (wl *wrappedListener) Close() error { 244 if wl.stopped { 245 return syscall.EINVAL 246 } 247 248 wl.stopped = true 249 return wl.Listener.Close() 250 } 251 252 func (wl *wrappedListener) File() (*os.File, error) { 253 // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes 254 return wl.Listener.(filer).File() 255 } 256 257 type wrappedConn struct { 258 net.Conn 259 server *Server 260 closed *int32 261 deadline time.Time 262 perWriteTimeout time.Duration 263 perWritePerKbTimeout time.Duration 264 } 265 266 func (w *wrappedConn) Write(p []byte) (n int, err error) { 267 if w.perWriteTimeout > 0 { 268 minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout 269 minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout) 270 271 w.deadline = w.deadline.Add(minTimeout) 272 if minDeadline.After(w.deadline) { 273 w.deadline = minDeadline 274 } 275 _ = w.Conn.SetWriteDeadline(w.deadline) 276 } 277 return w.Conn.Write(p) 278 } 279 280 func (w *wrappedConn) Close() error { 281 if atomic.CompareAndSwapInt32(w.closed, 0, 1) { 282 defer func() { 283 if err := recover(); err != nil { 284 select { 285 case <-GetManager().IsHammer(): 286 // Likely deadlocked request released at hammertime 287 log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err) 288 default: 289 log.Error("Panic during connection close! %v", err) 290 } 291 } 292 }() 293 w.server.wg.Done() 294 } 295 return w.Conn.Close() 296 }