github.com/gofunct/common@v0.0.0-20190131174352-fd058c7fbf22/pkg/transport/handlers/handlers.go (about)

     1  package handlers
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"sort"
     9  	"strings"
    10  )
    11  
    12  // MethodHandler is an http.Handler that dispatches to a handler whose key in the
    13  // MethodHandler's map matches the name of the HTTP request's method, eg: GET
    14  //
    15  // If the request's method is OPTIONS and OPTIONS is not a key in the map then
    16  // the handler responds with a status of 200 and sets the Allow header to a
    17  // comma-separated list of available methods.
    18  //
    19  // If the request's method doesn't match any of its keys the handler responds
    20  // with a status of HTTP 405 "Method Not Allowed" and sets the Allow header to a
    21  // comma-separated list of available methods.
    22  type MethodHandler map[string]http.Handler
    23  
    24  func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    25  	if handler, ok := h[req.Method]; ok {
    26  		handler.ServeHTTP(w, req)
    27  	} else {
    28  		allow := []string{}
    29  		for k := range h {
    30  			allow = append(allow, k)
    31  		}
    32  		sort.Strings(allow)
    33  		w.Header().Set("Allow", strings.Join(allow, ", "))
    34  		if req.Method == "OPTIONS" {
    35  			w.WriteHeader(http.StatusOK)
    36  		} else {
    37  			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    38  		}
    39  	}
    40  }
    41  
    42  // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP
    43  // status code and body size
    44  type responseLogger struct {
    45  	w      http.ResponseWriter
    46  	status int
    47  	size   int
    48  }
    49  
    50  func (l *responseLogger) Header() http.Header {
    51  	return l.w.Header()
    52  }
    53  
    54  func (l *responseLogger) Write(b []byte) (int, error) {
    55  	size, err := l.w.Write(b)
    56  	l.size += size
    57  	return size, err
    58  }
    59  
    60  func (l *responseLogger) WriteHeader(s int) {
    61  	l.w.WriteHeader(s)
    62  	l.status = s
    63  }
    64  
    65  func (l *responseLogger) Status() int {
    66  	return l.status
    67  }
    68  
    69  func (l *responseLogger) Size() int {
    70  	return l.size
    71  }
    72  
    73  func (l *responseLogger) Flush() {
    74  	f, ok := l.w.(http.Flusher)
    75  	if ok {
    76  		f.Flush()
    77  	}
    78  }
    79  
    80  type hijackLogger struct {
    81  	responseLogger
    82  }
    83  
    84  func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
    85  	h := l.responseLogger.w.(http.Hijacker)
    86  	conn, rw, err := h.Hijack()
    87  	if err == nil && l.responseLogger.status == 0 {
    88  		// The status will be StatusSwitchingProtocols if there was no error and
    89  		// WriteHeader has not been called yet
    90  		l.responseLogger.status = http.StatusSwitchingProtocols
    91  	}
    92  	return conn, rw, err
    93  }
    94  
// closeNotifyWriter bundles a loggingResponseWriter (declared elsewhere in this
// package) with an http.CloseNotifier, so one value exposes the embedded
// methods of both.
//
// NOTE(review): http.CloseNotifier is deprecated in net/http in favour of
// Request.Context(); confirm this wrapper is still needed.
type closeNotifyWriter struct {
	loggingResponseWriter
	http.CloseNotifier
}
    99  
// hijackCloseNotifier bundles a loggingResponseWriter (declared elsewhere in
// this package) with an http.Hijacker and an http.CloseNotifier, so one value
// exposes the embedded methods of all three.
//
// NOTE(review): presumably chosen when the original ResponseWriter supports both
// hijacking and close notification — confirm at the construction site.
// http.CloseNotifier is deprecated in net/http in favour of Request.Context().
type hijackCloseNotifier struct {
	loggingResponseWriter
	http.Hijacker
	http.CloseNotifier
}
   105  
   106  // isContentType validates the Content-Type header matches the supplied
   107  // contentType. That is, its type and subtype match.
   108  func isContentType(h http.Header, contentType string) bool {
   109  	ct := h.Get("Content-Type")
   110  	if i := strings.IndexRune(ct, ';'); i != -1 {
   111  		ct = ct[0:i]
   112  	}
   113  	return ct == contentType
   114  }
   115  
   116  // ContentTypeHandler wraps and returns a http.Handler, validating the request
   117  // content type is compatible with the contentTypes list. It writes a HTTP 415
   118  // error if that fails.
   119  //
   120  // Only PUT, POST, and PATCH requests are considered.
   121  func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler {
   122  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   123  		if !(r.Method == "PUT" || r.Method == "POST" || r.Method == "PATCH") {
   124  			h.ServeHTTP(w, r)
   125  			return
   126  		}
   127  
   128  		for _, ct := range contentTypes {
   129  			if isContentType(r.Header, ct) {
   130  				h.ServeHTTP(w, r)
   131  				return
   132  			}
   133  		}
   134  		http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType)
   135  	})
   136  }
   137  
   138  const (
   139  	// HTTPMethodOverrideHeader is a commonly used
   140  	// http header to override a request method.
   141  	HTTPMethodOverrideHeader = "X-HTTP-Method-Override"
   142  	// HTTPMethodOverrideFormKey is a commonly used
   143  	// HTML form key to override a request method.
   144  	HTTPMethodOverrideFormKey = "_method"
   145  )
   146  
   147  // HTTPMethodOverrideHandler wraps and returns a http.Handler which checks for
   148  // the X-HTTP-Method-Override header or the _method form key, and overrides (if
   149  // valid) request.Method with its value.
   150  //
   151  // This is especially useful for HTTP clients that don't support many http verbs.
   152  // It isn't secure to override e.g a GET to a POST, so only POST requests are
   153  // considered.  Likewise, the override method can only be a "write" method: PUT,
   154  // PATCH or DELETE.
   155  //
   156  // Form method takes precedence over header method.
   157  func HTTPMethodOverrideHandler(h http.Handler) http.Handler {
   158  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   159  		if r.Method == "POST" {
   160  			om := r.FormValue(HTTPMethodOverrideFormKey)
   161  			if om == "" {
   162  				om = r.Header.Get(HTTPMethodOverrideHeader)
   163  			}
   164  			if om == "PUT" || om == "PATCH" || om == "DELETE" {
   165  				r.Method = om
   166  			}
   167  		}
   168  		h.ServeHTTP(w, r)
   169  	})
   170  }