github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/errorhandler.go (about)

     1  package middleware
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  )
     9  
    10  func copyHeaders(src, dest http.Header) {
    11  	for k, v := range src {
    12  		dest[k] = v
    13  	}
    14  }
    15  
    16  // ErrorHandler lets you call an alternate http handler upon a certain response code.
    17  // Note it will assume a 200 if the wrapped handler does not write anything
    18  type ErrorHandler struct {
    19  	Code    int
    20  	Handler http.Handler
    21  }
    22  
    23  // Wrap implements Middleware
    24  func (e ErrorHandler) Wrap(next http.Handler) http.Handler {
    25  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    26  		i := newErrorInterceptor(w, e.Code)
    27  		next.ServeHTTP(i, r)
    28  		if !i.gotCode {
    29  			i.WriteHeader(http.StatusOK)
    30  		}
    31  		if i.intercepted {
    32  			e.Handler.ServeHTTP(w, r)
    33  		}
    34  	})
    35  }
    36  
    37  // errorInterceptor wraps an underlying ResponseWriter and buffers all header changes, until it knows the return code.
    38  // It then passes everything through, unless the code matches the target code, in which case it will discard everything.
    39  type errorInterceptor struct {
    40  	originalWriter http.ResponseWriter
    41  	targetCode     int
    42  	headers        http.Header
    43  	gotCode        bool
    44  	intercepted    bool
    45  }
    46  
    47  func newErrorInterceptor(w http.ResponseWriter, code int) *errorInterceptor {
    48  	i := errorInterceptor{originalWriter: w, targetCode: code}
    49  	i.headers = make(http.Header)
    50  	copyHeaders(w.Header(), i.headers)
    51  	return &i
    52  }
    53  
    54  // Unwrap method is used by http.ResponseController to get access to original http.ResponseWriter.
    55  func (i *errorInterceptor) Unwrap() http.ResponseWriter {
    56  	return i.originalWriter
    57  }
    58  
    59  // Header implements http.ResponseWriter
    60  func (i *errorInterceptor) Header() http.Header {
    61  	return i.headers
    62  }
    63  
    64  // WriteHeader implements http.ResponseWriter
    65  func (i *errorInterceptor) WriteHeader(code int) {
    66  	if i.gotCode {
    67  		panic("errorInterceptor.WriteHeader() called twice")
    68  	}
    69  
    70  	i.gotCode = true
    71  	if code == i.targetCode {
    72  		i.intercepted = true
    73  	} else {
    74  		copyHeaders(i.headers, i.originalWriter.Header())
    75  		i.originalWriter.WriteHeader(code)
    76  	}
    77  }
    78  
    79  // Write implements http.ResponseWriter
    80  func (i *errorInterceptor) Write(data []byte) (int, error) {
    81  	if !i.gotCode {
    82  		i.WriteHeader(http.StatusOK)
    83  	}
    84  	if !i.intercepted {
    85  		return i.originalWriter.Write(data)
    86  	}
    87  	return len(data), nil
    88  }
    89  
    90  // errorInterceptor also implements net.Hijacker, to let the downstream Handler
    91  // hijack the connection. This is needed, for example, for working with websockets.
    92  func (i *errorInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) {
    93  	hj, ok := i.originalWriter.(http.Hijacker)
    94  	if !ok {
    95  		return nil, nil, fmt.Errorf("error interceptor: can't cast original ResponseWriter to Hijacker")
    96  	}
    97  	i.gotCode = true
    98  	return hj.Hijack()
    99  }