github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/timeouthandler.go (about)

     1  package handler
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"path"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/lingyao2333/mo-zero/rest/httpx"
    17  	"github.com/lingyao2333/mo-zero/rest/internal"
    18  )
    19  
// Constants used by the timeout handler in this file.
const (
	statusClientClosedRequest = 499               // non-standard status (defined by nginx) written when the request context was canceled
	reason                    = "Request Timeout" // body text returned by errorBody
	headerUpgrade             = "Upgrade"         // request header inspected to detect websocket upgrades
	valueWebsocket            = "websocket"       // Upgrade value that makes ServeHTTP bypass the timeout logic
)
    26  
    27  // TimeoutHandler returns the handler with given timeout.
    28  // If client closed request, code 499 will be logged.
    29  // Notice: even if canceled in server side, 499 will be logged as well.
    30  func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
    31  	return func(next http.Handler) http.Handler {
    32  		if duration > 0 {
    33  			return &timeoutHandler{
    34  				handler: next,
    35  				dt:      duration,
    36  			}
    37  		}
    38  
    39  		return next
    40  	}
    41  }
    42  
// timeoutHandler is the handler that controls the request timeout.
// We implement it on our own because the stdlib implementation
// treats the ClientClosedRequest as http.StatusServiceUnavailable,
// whereas we write such requests in logs as code 499, which is defined by nginx.
type timeoutHandler struct {
	handler http.Handler  // wrapped handler that actually serves the request
	dt      time.Duration // timeout applied to the context of every request
}
    51  
// errorBody returns the response body that ServeHTTP writes when the request
// context is done before the wrapped handler finishes.
func (h *timeoutHandler) errorBody() string {
	return reason
}
    55  
    56  func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    57  	if r.Header.Get(headerUpgrade) == valueWebsocket {
    58  		h.handler.ServeHTTP(w, r)
    59  		return
    60  	}
    61  
    62  	ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
    63  	defer cancelCtx()
    64  
    65  	r = r.WithContext(ctx)
    66  	done := make(chan struct{})
    67  	tw := &timeoutWriter{
    68  		w:   w,
    69  		h:   make(http.Header),
    70  		req: r,
    71  	}
    72  	panicChan := make(chan interface{}, 1)
    73  	go func() {
    74  		defer func() {
    75  			if p := recover(); p != nil {
    76  				panicChan <- p
    77  			}
    78  		}()
    79  		h.handler.ServeHTTP(tw, r)
    80  		close(done)
    81  	}()
    82  	select {
    83  	case p := <-panicChan:
    84  		panic(p)
    85  	case <-done:
    86  		tw.mu.Lock()
    87  		defer tw.mu.Unlock()
    88  		dst := w.Header()
    89  		for k, vv := range tw.h {
    90  			dst[k] = vv
    91  		}
    92  		if !tw.wroteHeader {
    93  			tw.code = http.StatusOK
    94  		}
    95  		w.WriteHeader(tw.code)
    96  		w.Write(tw.wbuf.Bytes())
    97  	case <-ctx.Done():
    98  		tw.mu.Lock()
    99  		defer tw.mu.Unlock()
   100  		// there isn't any user-defined middleware before TimoutHandler,
   101  		// so we can guarantee that cancelation in biz related code won't come here.
   102  		httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
   103  			if errors.Is(err, context.Canceled) {
   104  				w.WriteHeader(statusClientClosedRequest)
   105  			} else {
   106  				w.WriteHeader(http.StatusServiceUnavailable)
   107  			}
   108  			io.WriteString(w, h.errorBody())
   109  		})
   110  		tw.timedOut = true
   111  	}
   112  }
   113  
// timeoutWriter is the http.ResponseWriter handed to the wrapped handler.
// It buffers headers, status code and body so that ServeHTTP can decide
// whether to flush them to the real writer or to discard them on timeout.
type timeoutWriter struct {
	w    http.ResponseWriter // real response writer; used directly only by Push
	h    http.Header         // temporary headers set by the handler
	wbuf bytes.Buffer        // buffered response body
	req  *http.Request       // request passed to internal.Errorf when logging superfluous WriteHeader calls

	mu          sync.Mutex // guards the fields below and wbuf
	timedOut    bool       // set by ServeHTTP once the context is done; further writes are rejected
	wroteHeader bool       // whether a status code has been recorded
	code        int        // recorded status code
}

// Compile-time check that timeoutWriter implements http.Pusher.
var _ http.Pusher = (*timeoutWriter)(nil)
   127  
// Header returns the underlying temporary http.Header.
// ServeHTTP copies its entries to the real response only when the handler
// finishes before the timeout.
func (tw *timeoutWriter) Header() http.Header { return tw.h }
   130  
   131  // Push implements the Pusher interface.
   132  func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
   133  	if pusher, ok := tw.w.(http.Pusher); ok {
   134  		return pusher.Push(target, opts)
   135  	}
   136  	return http.ErrNotSupported
   137  }
   138  
   139  // Write writes the data to the connection as part of an HTTP reply.
   140  // Timeout and multiple header written are guarded.
   141  func (tw *timeoutWriter) Write(p []byte) (int, error) {
   142  	tw.mu.Lock()
   143  	defer tw.mu.Unlock()
   144  
   145  	if tw.timedOut {
   146  		return 0, http.ErrHandlerTimeout
   147  	}
   148  
   149  	if !tw.wroteHeader {
   150  		tw.writeHeaderLocked(http.StatusOK)
   151  	}
   152  	return tw.wbuf.Write(p)
   153  }
   154  
   155  func (tw *timeoutWriter) writeHeaderLocked(code int) {
   156  	checkWriteHeaderCode(code)
   157  
   158  	switch {
   159  	case tw.timedOut:
   160  		return
   161  	case tw.wroteHeader:
   162  		if tw.req != nil {
   163  			caller := relevantCaller()
   164  			internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
   165  				caller.Function, path.Base(caller.File), caller.Line)
   166  		}
   167  	default:
   168  		tw.wroteHeader = true
   169  		tw.code = code
   170  	}
   171  }
   172  
// WriteHeader records the status code for the buffered response.
// It takes tw.mu and delegates to writeHeaderLocked, so timeouts and
// repeated calls are handled there.
func (tw *timeoutWriter) WriteHeader(code int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.writeHeaderLocked(code)
}
   178  
   179  func checkWriteHeaderCode(code int) {
   180  	if code < 100 || code > 599 {
   181  		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
   182  	}
   183  }
   184  
   185  // relevantCaller searches the call stack for the first function outside of net/http.
   186  // The purpose of this function is to provide more helpful error messages.
   187  func relevantCaller() runtime.Frame {
   188  	pc := make([]uintptr, 16)
   189  	n := runtime.Callers(1, pc)
   190  	frames := runtime.CallersFrames(pc[:n])
   191  	var frame runtime.Frame
   192  	for {
   193  		frame, more := frames.Next()
   194  		if !strings.HasPrefix(frame.Function, "net/http.") {
   195  			return frame
   196  		}
   197  		if !more {
   198  			break
   199  		}
   200  	}
   201  	return frame
   202  }