github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/errorhandler.go (about) 1 package middleware 2 3 import ( 4 "bufio" 5 "fmt" 6 "net" 7 "net/http" 8 ) 9 10 func copyHeaders(src, dest http.Header) { 11 for k, v := range src { 12 dest[k] = v 13 } 14 } 15 16 // ErrorHandler lets you call an alternate http handler upon a certain response code. 17 // Note it will assume a 200 if the wrapped handler does not write anything 18 type ErrorHandler struct { 19 Code int 20 Handler http.Handler 21 } 22 23 // Wrap implements Middleware 24 func (e ErrorHandler) Wrap(next http.Handler) http.Handler { 25 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 i := newErrorInterceptor(w, e.Code) 27 next.ServeHTTP(i, r) 28 if !i.gotCode { 29 i.WriteHeader(http.StatusOK) 30 } 31 if i.intercepted { 32 e.Handler.ServeHTTP(w, r) 33 } 34 }) 35 } 36 37 // errorInterceptor wraps an underlying ResponseWriter and buffers all header changes, until it knows the return code. 38 // It then passes everything through, unless the code matches the target code, in which case it will discard everything. 39 type errorInterceptor struct { 40 originalWriter http.ResponseWriter 41 targetCode int 42 headers http.Header 43 gotCode bool 44 intercepted bool 45 } 46 47 func newErrorInterceptor(w http.ResponseWriter, code int) *errorInterceptor { 48 i := errorInterceptor{originalWriter: w, targetCode: code} 49 i.headers = make(http.Header) 50 copyHeaders(w.Header(), i.headers) 51 return &i 52 } 53 54 // Unwrap method is used by http.ResponseController to get access to original http.ResponseWriter. 55 func (i *errorInterceptor) Unwrap() http.ResponseWriter { 56 return i.originalWriter 57 } 58 59 // Header implements http.ResponseWriter 60 func (i *errorInterceptor) Header() http.Header { 61 return i.headers 62 } 63 64 // WriteHeader implements http.ResponseWriter 65 func (i *errorInterceptor) WriteHeader(code int) { 66 if i.gotCode { 67 panic("errorInterceptor.WriteHeader() called twice") 68 } 69 70 i.gotCode = true 71 if code == i.targetCode { 72 i.intercepted = true 73 } else { 74 copyHeaders(i.headers, i.originalWriter.Header()) 75 i.originalWriter.WriteHeader(code) 76 } 77 } 78 79 // Write implements http.ResponseWriter 80 func (i *errorInterceptor) Write(data []byte) (int, error) { 81 if !i.gotCode { 82 i.WriteHeader(http.StatusOK) 83 } 84 if !i.intercepted { 85 return i.originalWriter.Write(data) 86 } 87 return len(data), nil 88 } 89 90 // errorInterceptor also implements net.Hijacker, to let the downstream Handler 91 // hijack the connection. This is needed, for example, for working with websockets. 92 func (i *errorInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) { 93 hj, ok := i.originalWriter.(http.Hijacker) 94 if !ok { 95 return nil, nil, fmt.Errorf("error interceptor: can't cast original ResponseWriter to Hijacker") 96 } 97 i.gotCode = true 98 return hj.Hijack() 99 }