github.com/gofiber/fiber/v2@v2.47.0/middleware/limiter/limiter_sliding.go (about)

     1  package limiter
     2  
     3  import (
     4  	"strconv"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/gofiber/fiber/v2"
    10  	"github.com/gofiber/fiber/v2/utils"
    11  )
    12  
    13  type SlidingWindow struct{}
    14  
    15  // New creates a new sliding window middleware handler
    16  func (SlidingWindow) New(cfg Config) fiber.Handler {
    17  	var (
    18  		// Limiter variables
    19  		mux        = &sync.RWMutex{}
    20  		max        = strconv.Itoa(cfg.Max)
    21  		expiration = uint64(cfg.Expiration.Seconds())
    22  	)
    23  
    24  	// Create manager to simplify storage operations ( see manager.go )
    25  	manager := newManager(cfg.Storage)
    26  
    27  	// Update timestamp every second
    28  	utils.StartTimeStampUpdater()
    29  
    30  	// Return new handler
    31  	return func(c *fiber.Ctx) error {
    32  		// Don't execute middleware if Next returns true
    33  		if cfg.Next != nil && cfg.Next(c) {
    34  			return c.Next()
    35  		}
    36  
    37  		// Get key from request
    38  		key := cfg.KeyGenerator(c)
    39  
    40  		// Lock entry
    41  		mux.Lock()
    42  
    43  		// Get entry from pool and release when finished
    44  		e := manager.get(key)
    45  
    46  		// Get timestamp
    47  		ts := uint64(atomic.LoadUint32(&utils.Timestamp))
    48  
    49  		// Set expiration if entry does not exist
    50  		if e.exp == 0 {
    51  			e.exp = ts + expiration
    52  		} else if ts >= e.exp {
    53  			// The entry has expired, handle the expiration.
    54  			// Set the prevHits to the current hits and reset the hits to 0.
    55  			e.prevHits = e.currHits
    56  
    57  			// Reset the current hits to 0.
    58  			e.currHits = 0
    59  
    60  			// Check how much into the current window it currently is and sets the
    61  			// expiry based on that, otherwise this would only reset on
    62  			// the next request and not show the correct expiry.
    63  			elapsed := ts - e.exp
    64  			if elapsed >= expiration {
    65  				e.exp = ts + expiration
    66  			} else {
    67  				e.exp = ts + expiration - elapsed
    68  			}
    69  		}
    70  
    71  		// Increment hits
    72  		e.currHits++
    73  
    74  		// Calculate when it resets in seconds
    75  		resetInSec := e.exp - ts
    76  
    77  		// weight = time until current window reset / total window length
    78  		weight := float64(resetInSec) / float64(expiration)
    79  
    80  		// rate = request count in previous window - weight + request count in current window
    81  		rate := int(float64(e.prevHits)*weight) + e.currHits
    82  
    83  		// Calculate how many hits can be made based on the current rate
    84  		remaining := cfg.Max - rate
    85  
    86  		// Update storage. Garbage collect when the next window ends.
    87  		// |--------------------------|--------------------------|
    88  		//               ^            ^               ^          ^
    89  		//              ts         e.exp   End sample window   End next window
    90  		//               <------------>
    91  		// 				   resetInSec
    92  		// resetInSec = e.exp - ts - time until end of current window.
    93  		// duration + expiration = end of next window.
    94  		// Because we don't want to garbage collect in the middle of a window
    95  		// we add the expiration to the duration.
    96  		// Otherwise after the end of "sample window", attackers could launch
    97  		// a new request with the full window length.
    98  		manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
    99  
   100  		// Unlock entry
   101  		mux.Unlock()
   102  
   103  		// Check if hits exceed the cfg.Max
   104  		if remaining < 0 {
   105  			// Return response with Retry-After header
   106  			// https://tools.ietf.org/html/rfc6584
   107  			c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
   108  
   109  			// Call LimitReached handler
   110  			return cfg.LimitReached(c)
   111  		}
   112  
   113  		// Continue stack for reaching c.Response().StatusCode()
   114  		// Store err for returning
   115  		err := c.Next()
   116  
   117  		// Check for SkipFailedRequests and SkipSuccessfulRequests
   118  		if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
   119  			(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
   120  			// Lock entry
   121  			mux.Lock()
   122  			e = manager.get(key)
   123  			e.currHits--
   124  			remaining++
   125  			manager.set(key, e, cfg.Expiration)
   126  			// Unlock entry
   127  			mux.Unlock()
   128  		}
   129  
   130  		// We can continue, update RateLimit headers
   131  		c.Set(xRateLimitLimit, max)
   132  		c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
   133  		c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
   134  
   135  		return err
   136  	}
   137  }