github.com/gitbundle/modules@v0.0.0-20231025071548-85b91c5c3b01/graceful/server.go (about) 1 // Copyright 2023 The GitBundle Inc. All rights reserved. 2 // Copyright 2017 The Gitea Authors. All rights reserved. 3 // Use of this source code is governed by a MIT-style 4 // license that can be found in the LICENSE file. 5 6 // This code is highly inspired by endless go 7 8 package graceful 9 10 import ( 11 "crypto/tls" 12 "net" 13 "os" 14 "strings" 15 "sync" 16 "sync/atomic" 17 "syscall" 18 "time" 19 20 "github.com/gitbundle/modules/log" 21 "github.com/gitbundle/modules/setting" 22 ) 23 24 var ( 25 // DefaultReadTimeOut default read timeout 26 DefaultReadTimeOut time.Duration 27 // DefaultWriteTimeOut default write timeout 28 DefaultWriteTimeOut time.Duration 29 // DefaultMaxHeaderBytes default max header bytes 30 DefaultMaxHeaderBytes int 31 // PerWriteWriteTimeout timeout for writes 32 PerWriteWriteTimeout = 30 * time.Second 33 // PerWriteWriteTimeoutKbTime is a timeout taking account of how much there is to be written 34 PerWriteWriteTimeoutKbTime = 10 * time.Second 35 ) 36 37 func init() { 38 DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB) 39 } 40 41 // ServeFunction represents a listen.Accept loop 42 type ServeFunction = func(net.Listener) error 43 44 // Server represents our graceful server 45 type Server struct { 46 network string 47 address string 48 listener net.Listener 49 wg sync.WaitGroup 50 state state 51 lock *sync.RWMutex 52 BeforeBegin func(network, address string) 53 OnShutdown func() 54 PerWriteTimeout time.Duration 55 PerWritePerKbTimeout time.Duration 56 } 57 58 // NewServer creates a server on network at provided address 59 func NewServer(network, address, name string) *Server { 60 if GetManager().IsChild() { 61 log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) 62 } else { 63 log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) 64 } 65 srv := &Server{ 66 wg: sync.WaitGroup{}, 67 state: stateInit, 68 lock: &sync.RWMutex{}, 69 network: network, 70 address: address, 71 PerWriteTimeout: setting.PerWriteTimeout, 72 PerWritePerKbTimeout: setting.PerWritePerKbTimeout, 73 } 74 75 srv.BeforeBegin = func(network, addr string) { 76 log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid()) 77 } 78 79 return srv 80 } 81 82 // ListenAndServe listens on the provided network address and then calls Serve 83 // to handle requests on incoming connections. 84 func (srv *Server) ListenAndServe(serve ServeFunction) error { 85 go srv.awaitShutdown() 86 87 l, err := GetListener(srv.network, srv.address) 88 if err != nil { 89 log.Error("Unable to GetListener: %v", err) 90 return err 91 } 92 93 srv.listener = newWrappedListener(l, srv) 94 95 srv.BeforeBegin(srv.network, srv.address) 96 97 return srv.Serve(serve) 98 } 99 100 // ListenAndServeTLSConfig listens on the provided network address and then calls 101 // Serve to handle requests on incoming TLS connections. 102 func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error { 103 go srv.awaitShutdown() 104 105 if tlsConfig.MinVersion == 0 { 106 tlsConfig.MinVersion = tls.VersionTLS12 107 } 108 109 l, err := GetListener(srv.network, srv.address) 110 if err != nil { 111 log.Error("Unable to get Listener: %v", err) 112 return err 113 } 114 115 wl := newWrappedListener(l, srv) 116 srv.listener = tls.NewListener(wl, tlsConfig) 117 118 srv.BeforeBegin(srv.network, srv.address) 119 120 return srv.Serve(serve) 121 } 122 123 // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new 124 // service goroutine for each. The service goroutines read requests and then call 125 // handler to reply to them. Handler is typically nil, in which case the 126 // DefaultServeMux is used. 127 // 128 // In addition to the standard Serve behaviour each connection is added to a 129 // sync.Waitgroup so that all outstanding connections can be served before shutting 130 // down the server. 131 func (srv *Server) Serve(serve ServeFunction) error { 132 defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid()) 133 srv.setState(stateRunning) 134 GetManager().RegisterServer() 135 err := serve(srv.listener) 136 log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid()) 137 srv.wg.Wait() 138 srv.setState(stateTerminate) 139 GetManager().ServerDone() 140 // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil 141 if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") { 142 return nil 143 } 144 return err 145 } 146 147 func (srv *Server) getState() state { 148 srv.lock.RLock() 149 defer srv.lock.RUnlock() 150 151 return srv.state 152 } 153 154 func (srv *Server) setState(st state) { 155 srv.lock.Lock() 156 defer srv.lock.Unlock() 157 158 srv.state = st 159 } 160 161 type filer interface { 162 File() (*os.File, error) 163 } 164 165 type wrappedListener struct { 166 net.Listener 167 stopped bool 168 server *Server 169 } 170 171 func newWrappedListener(l net.Listener, srv *Server) *wrappedListener { 172 return &wrappedListener{ 173 Listener: l, 174 server: srv, 175 } 176 } 177 178 func (wl *wrappedListener) Accept() (net.Conn, error) { 179 var c net.Conn 180 // Set keepalive on TCPListeners connections. 181 if tcl, ok := wl.Listener.(*net.TCPListener); ok { 182 tc, err := tcl.AcceptTCP() 183 if err != nil { 184 return nil, err 185 } 186 _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener 187 _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener 188 c = tc 189 } else { 190 var err error 191 c, err = wl.Listener.Accept() 192 if err != nil { 193 return nil, err 194 } 195 } 196 197 closed := int32(0) 198 199 c = &wrappedConn{ 200 Conn: c, 201 server: wl.server, 202 closed: &closed, 203 perWriteTimeout: wl.server.PerWriteTimeout, 204 perWritePerKbTimeout: wl.server.PerWritePerKbTimeout, 205 } 206 207 wl.server.wg.Add(1) 208 return c, nil 209 } 210 211 func (wl *wrappedListener) Close() error { 212 if wl.stopped { 213 return syscall.EINVAL 214 } 215 216 wl.stopped = true 217 return wl.Listener.Close() 218 } 219 220 func (wl *wrappedListener) File() (*os.File, error) { 221 // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes 222 return wl.Listener.(filer).File() 223 } 224 225 type wrappedConn struct { 226 net.Conn 227 server *Server 228 closed *int32 229 deadline time.Time 230 perWriteTimeout time.Duration 231 perWritePerKbTimeout time.Duration 232 } 233 234 func (w *wrappedConn) Write(p []byte) (n int, err error) { 235 if w.perWriteTimeout > 0 { 236 minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout 237 minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout) 238 239 w.deadline = w.deadline.Add(minTimeout) 240 if minDeadline.After(w.deadline) { 241 w.deadline = minDeadline 242 } 243 _ = w.Conn.SetWriteDeadline(w.deadline) 244 } 245 return w.Conn.Write(p) 246 } 247 248 func (w *wrappedConn) Close() error { 249 if atomic.CompareAndSwapInt32(w.closed, 0, 1) { 250 defer func() { 251 if err := recover(); err != nil { 252 select { 253 case <-GetManager().IsHammer(): 254 // Likely deadlocked request released at hammertime 255 log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err) 256 default: 257 log.Error("Panic during connection close! %v", err) 258 } 259 } 260 }() 261 w.server.wg.Done() 262 } 263 return w.Conn.Close() 264 }