github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/timeouthandler.go (about) 1 package handler 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "net/http" 10 "path" 11 "runtime" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/lingyao2333/mo-zero/rest/httpx" 17 "github.com/lingyao2333/mo-zero/rest/internal" 18 ) 19 20 const ( 21 statusClientClosedRequest = 499 22 reason = "Request Timeout" 23 headerUpgrade = "Upgrade" 24 valueWebsocket = "websocket" 25 ) 26 27 // TimeoutHandler returns the handler with given timeout. 28 // If client closed request, code 499 will be logged. 29 // Notice: even if canceled in server side, 499 will be logged as well. 30 func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler { 31 return func(next http.Handler) http.Handler { 32 if duration > 0 { 33 return &timeoutHandler{ 34 handler: next, 35 dt: duration, 36 } 37 } 38 39 return next 40 } 41 } 42 43 // timeoutHandler is the handler that controls the request timeout. 44 // Why we implement it on our own, because the stdlib implementation 45 // treats the ClientClosedRequest as http.StatusServiceUnavailable. 46 // And we write the codes in logs as code 499, which is defined by nginx. 47 type timeoutHandler struct { 48 handler http.Handler 49 dt time.Duration 50 } 51 52 func (h *timeoutHandler) errorBody() string { 53 return reason 54 } 55 56 func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 57 if r.Header.Get(headerUpgrade) == valueWebsocket { 58 h.handler.ServeHTTP(w, r) 59 return 60 } 61 62 ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt) 63 defer cancelCtx() 64 65 r = r.WithContext(ctx) 66 done := make(chan struct{}) 67 tw := &timeoutWriter{ 68 w: w, 69 h: make(http.Header), 70 req: r, 71 } 72 panicChan := make(chan interface{}, 1) 73 go func() { 74 defer func() { 75 if p := recover(); p != nil { 76 panicChan <- p 77 } 78 }() 79 h.handler.ServeHTTP(tw, r) 80 close(done) 81 }() 82 select { 83 case p := <-panicChan: 84 panic(p) 85 case <-done: 86 tw.mu.Lock() 87 defer tw.mu.Unlock() 88 dst := w.Header() 89 for k, vv := range tw.h { 90 dst[k] = vv 91 } 92 if !tw.wroteHeader { 93 tw.code = http.StatusOK 94 } 95 w.WriteHeader(tw.code) 96 w.Write(tw.wbuf.Bytes()) 97 case <-ctx.Done(): 98 tw.mu.Lock() 99 defer tw.mu.Unlock() 100 // there isn't any user-defined middleware before TimoutHandler, 101 // so we can guarantee that cancelation in biz related code won't come here. 102 httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) { 103 if errors.Is(err, context.Canceled) { 104 w.WriteHeader(statusClientClosedRequest) 105 } else { 106 w.WriteHeader(http.StatusServiceUnavailable) 107 } 108 io.WriteString(w, h.errorBody()) 109 }) 110 tw.timedOut = true 111 } 112 } 113 114 type timeoutWriter struct { 115 w http.ResponseWriter 116 h http.Header 117 wbuf bytes.Buffer 118 req *http.Request 119 120 mu sync.Mutex 121 timedOut bool 122 wroteHeader bool 123 code int 124 } 125 126 var _ http.Pusher = (*timeoutWriter)(nil) 127 128 // Header returns the underline temporary http.Header. 129 func (tw *timeoutWriter) Header() http.Header { return tw.h } 130 131 // Push implements the Pusher interface. 132 func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error { 133 if pusher, ok := tw.w.(http.Pusher); ok { 134 return pusher.Push(target, opts) 135 } 136 return http.ErrNotSupported 137 } 138 139 // Write writes the data to the connection as part of an HTTP reply. 140 // Timeout and multiple header written are guarded. 141 func (tw *timeoutWriter) Write(p []byte) (int, error) { 142 tw.mu.Lock() 143 defer tw.mu.Unlock() 144 145 if tw.timedOut { 146 return 0, http.ErrHandlerTimeout 147 } 148 149 if !tw.wroteHeader { 150 tw.writeHeaderLocked(http.StatusOK) 151 } 152 return tw.wbuf.Write(p) 153 } 154 155 func (tw *timeoutWriter) writeHeaderLocked(code int) { 156 checkWriteHeaderCode(code) 157 158 switch { 159 case tw.timedOut: 160 return 161 case tw.wroteHeader: 162 if tw.req != nil { 163 caller := relevantCaller() 164 internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", 165 caller.Function, path.Base(caller.File), caller.Line) 166 } 167 default: 168 tw.wroteHeader = true 169 tw.code = code 170 } 171 } 172 173 func (tw *timeoutWriter) WriteHeader(code int) { 174 tw.mu.Lock() 175 defer tw.mu.Unlock() 176 tw.writeHeaderLocked(code) 177 } 178 179 func checkWriteHeaderCode(code int) { 180 if code < 100 || code > 599 { 181 panic(fmt.Sprintf("invalid WriteHeader code %v", code)) 182 } 183 } 184 185 // relevantCaller searches the call stack for the first function outside of net/http. 186 // The purpose of this function is to provide more helpful error messages. 187 func relevantCaller() runtime.Frame { 188 pc := make([]uintptr, 16) 189 n := runtime.Callers(1, pc) 190 frames := runtime.CallersFrames(pc[:n]) 191 var frame runtime.Frame 192 for { 193 frame, more := frames.Next() 194 if !strings.HasPrefix(frame.Function, "net/http.") { 195 return frame 196 } 197 if !more { 198 break 199 } 200 } 201 return frame 202 }