github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/stdlib/middleware.go (about)

     1  package stdlib
     2  
     3  import (
     4  	"net/http"
     5  	"strconv"
     6  
     7  	"github.com/ulule/limiter/v3"
     8  )
     9  
// Middleware is the middleware for basic http.Handler.
//
// For each request it derives a key with KeyGetter and, unless that key is
// excluded, asks Limiter about it. The request is then forwarded to the
// wrapped handler, reported through OnError, or rejected through
// OnLimitReached. Build it with NewMiddleware: Handler calls KeyGetter,
// Limiter and both callbacks without nil checks (only ExcludedKey may be nil),
// so a zero-value Middleware is not usable.
type Middleware struct {
	// Limiter is queried once per non-excluded request via Limiter.Get.
	Limiter *limiter.Limiter
	// OnError is invoked when Limiter.Get returns an error; the wrapped
	// handler is not called in that case.
	OnError ErrorHandler
	// OnLimitReached is invoked instead of the wrapped handler when the
	// limiter result reports the limit as reached.
	OnLimitReached LimitReachedHandler
	// KeyGetter derives the rate-limit key from the incoming request.
	KeyGetter KeyGetter
	// ExcludedKey, when non-nil, is called with each request's key; returning
	// true sends the request straight to the wrapped handler without
	// consulting Limiter or adding X-RateLimit-* headers. Nil excludes nothing.
	ExcludedKey func(string) bool
}
    18  
    19  // NewMiddleware return a new instance of a basic HTTP middleware.
    20  func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
    21  	middleware := &Middleware{
    22  		Limiter:        limiter,
    23  		OnError:        DefaultErrorHandler,
    24  		OnLimitReached: DefaultLimitReachedHandler,
    25  		KeyGetter:      DefaultKeyGetter(limiter),
    26  		ExcludedKey:    nil,
    27  	}
    28  
    29  	for _, option := range options {
    30  		option.apply(middleware)
    31  	}
    32  
    33  	return middleware
    34  }
    35  
    36  // Handler handles a HTTP request.
    37  func (middleware *Middleware) Handler(h http.Handler) http.Handler {
    38  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    39  		key := middleware.KeyGetter(r)
    40  		if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
    41  			h.ServeHTTP(w, r)
    42  			return
    43  		}
    44  
    45  		context, err := middleware.Limiter.Get(r.Context(), key)
    46  		if err != nil {
    47  			middleware.OnError(w, r, err)
    48  			return
    49  		}
    50  
    51  		w.Header().Add("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10))
    52  		w.Header().Add("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10))
    53  		w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10))
    54  
    55  		if context.Reached {
    56  			middleware.OnLimitReached(w, r)
    57  			return
    58  		}
    59  
    60  		h.ServeHTTP(w, r)
    61  	})
    62  }