github.com/mvg-fi/go-limiter@v0.1.1/httplimit/middleware.go (about)

     1  // Package httplimit provides middleware for rate limiting HTTP handlers.
     2  //
     3  // The implementation is designed to work with Go's built-in http.Handler and
     4  // http.HandlerFunc interfaces, so it will also work with any popular web
     5  // frameworks that support middleware with these properties.
     6  package httplimit
     7  
     8  import (
     9  	"fmt"
    10  	"net"
    11  	"net/http"
    12  	"strconv"
    13  	"time"
    14  
    15  	"github.com/mvg-fi/go-limiter"
    16  )
    17  
    18  const (
    19  	// HeaderRateLimitLimit, HeaderRateLimitRemaining, and HeaderRateLimitReset
    20  	// are the recommended return header values from IETF on rate limiting. Reset
    21  	// is in UTC time.
    22  	HeaderRateLimitLimit     = "X-RateLimit-Limit"
    23  	HeaderRateLimitRemaining = "X-RateLimit-Remaining"
    24  	HeaderRateLimitReset     = "X-RateLimit-Reset"
    25  
    26  	// HeaderRetryAfter is the header used to indicate when a client should retry
    27  	// requests (when the rate limit expires), in UTC time.
    28  	HeaderRetryAfter = "Retry-After"
    29  )
    30  
    31  // KeyFunc is a function that accepts an http request and returns a string key
    32  // that uniquely identifies this request for the purpose of rate limiting.
    33  //
    34  // KeyFuncs are called on each request, so be mindful of performance and
    35  // implement caching where possible. If a KeyFunc returns an error, the HTTP
    36  // handler will return Internal Server Error and will NOT take from the limiter
    37  // store.
    38  type KeyFunc func(r *http.Request) (string, error)
    39  
    40  // IPKeyFunc returns a function that keys data based on the incoming requests IP
    41  // address. By default this uses the RemoteAddr, but you can also specify a list
    42  // of headers which will be checked for an IP address first (e.g.
    43  // "X-Forwarded-For"). Headers are retrieved using Header.Get(), which means
    44  // they are case insensitive.
    45  func IPKeyFunc(headers ...string) KeyFunc {
    46  	return func(r *http.Request) (string, error) {
    47  		for _, h := range headers {
    48  			if v := r.Header.Get(h); v != "" {
    49  				return v, nil
    50  			}
    51  		}
    52  
    53  		ip, _, err := net.SplitHostPort(r.RemoteAddr)
    54  		if err != nil {
    55  			return "", err
    56  		}
    57  		return ip, nil
    58  	}
    59  }
    60  
    61  // Middleware is a handler/mux that can wrap other middlware to implement HTTP
    62  // rate limiting. It can rate limit based on an arbitrary KeyFunc, and supports
    63  // anything that implements limiter.StoreWithContext.
    64  type Middleware struct {
    65  	store   limiter.Store
    66  	keyFunc KeyFunc
    67  }
    68  
    69  // NewMiddleware creates a new middleware suitable for use as an HTTP handler.
    70  // This function returns an error if either the Store or KeyFunc are nil.
    71  func NewMiddleware(s limiter.Store, f KeyFunc) (*Middleware, error) {
    72  	if s == nil {
    73  		return nil, fmt.Errorf("store cannot be nil")
    74  	}
    75  
    76  	if f == nil {
    77  		return nil, fmt.Errorf("key function cannot be nil")
    78  	}
    79  
    80  	return &Middleware{
    81  		store:   s,
    82  		keyFunc: f,
    83  	}, nil
    84  }
    85  
    86  // Handle returns the HTTP handler as a middleware. This handler calls Take() on
    87  // the store and sets the common rate limiting headers. If the take is
    88  // successful, the remaining middleware is called. If take is unsuccessful, the
    89  // middleware chain is halted and the function renders a 429 to the caller with
    90  // metadata about when it's safe to retry.
    91  func (m *Middleware) Handle(next http.Handler) http.Handler {
    92  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    93  		ctx := r.Context()
    94  
    95  		// Call the key function - if this fails, it's an internal server error.
    96  		key, err := m.keyFunc(r)
    97  		if err != nil {
    98  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    99  			return
   100  		}
   101  
   102  		// Take from the store.
   103  		limit, remaining, reset, ok, err := m.store.Take(ctx, key)
   104  		if err != nil {
   105  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   106  			return
   107  		}
   108  
   109  		resetTime := time.Unix(0, int64(reset)).UTC().Format(time.RFC1123)
   110  
   111  		// Set headers (we do this regardless of whether the request is permitted).
   112  		w.Header().Set(HeaderRateLimitLimit, strconv.FormatUint(limit, 10))
   113  		w.Header().Set(HeaderRateLimitRemaining, strconv.FormatUint(remaining, 10))
   114  		w.Header().Set(HeaderRateLimitReset, resetTime)
   115  
   116  		// Fail if there were no tokens remaining.
   117  		if !ok {
   118  			w.Header().Set(HeaderRetryAfter, resetTime)
   119  			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
   120  			return
   121  		}
   122  
   123  		// If we got this far, we're allowed to continue, so call the next middleware
   124  		// in the stack to continue processing.
   125  		next.ServeHTTP(w, r)
   126  	})
   127  }