github.com/mvg-fi/go-limiter@v0.1.1/httplimit/middleware.go (about) 1 // Package httplimit provides middleware for rate limiting HTTP handlers. 2 // 3 // The implementation is designed to work with Go's built-in http.Handler and 4 // http.HandlerFunc interfaces, so it will also work with any popular web 5 // frameworks that support middleware with these properties. 6 package httplimit 7 8 import ( 9 "fmt" 10 "net" 11 "net/http" 12 "strconv" 13 "time" 14 15 "github.com/mvg-fi/go-limiter" 16 ) 17 18 const ( 19 // HeaderRateLimitLimit, HeaderRateLimitRemaining, and HeaderRateLimitReset 20 // are the recommended return header values from IETF on rate limiting. Reset 21 // is in UTC time. 22 HeaderRateLimitLimit = "X-RateLimit-Limit" 23 HeaderRateLimitRemaining = "X-RateLimit-Remaining" 24 HeaderRateLimitReset = "X-RateLimit-Reset" 25 26 // HeaderRetryAfter is the header used to indicate when a client should retry 27 // requests (when the rate limit expires), in UTC time. 28 HeaderRetryAfter = "Retry-After" 29 ) 30 31 // KeyFunc is a function that accepts an http request and returns a string key 32 // that uniquely identifies this request for the purpose of rate limiting. 33 // 34 // KeyFuncs are called on each request, so be mindful of performance and 35 // implement caching where possible. If a KeyFunc returns an error, the HTTP 36 // handler will return Internal Server Error and will NOT take from the limiter 37 // store. 38 type KeyFunc func(r *http.Request) (string, error) 39 40 // IPKeyFunc returns a function that keys data based on the incoming requests IP 41 // address. By default this uses the RemoteAddr, but you can also specify a list 42 // of headers which will be checked for an IP address first (e.g. 43 // "X-Forwarded-For"). Headers are retrieved using Header.Get(), which means 44 // they are case insensitive. 45 func IPKeyFunc(headers ...string) KeyFunc { 46 return func(r *http.Request) (string, error) { 47 for _, h := range headers { 48 if v := r.Header.Get(h); v != "" { 49 return v, nil 50 } 51 } 52 53 ip, _, err := net.SplitHostPort(r.RemoteAddr) 54 if err != nil { 55 return "", err 56 } 57 return ip, nil 58 } 59 } 60 61 // Middleware is a handler/mux that can wrap other middlware to implement HTTP 62 // rate limiting. It can rate limit based on an arbitrary KeyFunc, and supports 63 // anything that implements limiter.StoreWithContext. 64 type Middleware struct { 65 store limiter.Store 66 keyFunc KeyFunc 67 } 68 69 // NewMiddleware creates a new middleware suitable for use as an HTTP handler. 70 // This function returns an error if either the Store or KeyFunc are nil. 71 func NewMiddleware(s limiter.Store, f KeyFunc) (*Middleware, error) { 72 if s == nil { 73 return nil, fmt.Errorf("store cannot be nil") 74 } 75 76 if f == nil { 77 return nil, fmt.Errorf("key function cannot be nil") 78 } 79 80 return &Middleware{ 81 store: s, 82 keyFunc: f, 83 }, nil 84 } 85 86 // Handle returns the HTTP handler as a middleware. This handler calls Take() on 87 // the store and sets the common rate limiting headers. If the take is 88 // successful, the remaining middleware is called. If take is unsuccessful, the 89 // middleware chain is halted and the function renders a 429 to the caller with 90 // metadata about when it's safe to retry. 91 func (m *Middleware) Handle(next http.Handler) http.Handler { 92 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 93 ctx := r.Context() 94 95 // Call the key function - if this fails, it's an internal server error. 96 key, err := m.keyFunc(r) 97 if err != nil { 98 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 99 return 100 } 101 102 // Take from the store. 103 limit, remaining, reset, ok, err := m.store.Take(ctx, key) 104 if err != nil { 105 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 106 return 107 } 108 109 resetTime := time.Unix(0, int64(reset)).UTC().Format(time.RFC1123) 110 111 // Set headers (we do this regardless of whether the request is permitted). 112 w.Header().Set(HeaderRateLimitLimit, strconv.FormatUint(limit, 10)) 113 w.Header().Set(HeaderRateLimitRemaining, strconv.FormatUint(remaining, 10)) 114 w.Header().Set(HeaderRateLimitReset, resetTime) 115 116 // Fail if there were no tokens remaining. 117 if !ok { 118 w.Header().Set(HeaderRetryAfter, resetTime) 119 http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) 120 return 121 } 122 123 // If we got this far, we're allowed to continue, so call the next middleware 124 // in the stack to continue processing. 125 next.ServeHTTP(w, r) 126 }) 127 }