// Source: github.com/adharshmk96/stk@v1.2.3/pkg/middleware/rate_limiter.go

     1  package middleware
     2  
import (
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/adharshmk96/stk/gsk"
)
    10  
    11  type RateLimiter struct {
    12  	requestsPerInterval int
    13  	interval            time.Duration
    14  	accessCounter       map[string]int
    15  	mux                 *sync.Mutex
    16  	Middleware          gsk.Middleware
    17  }
    18  
    19  type RateLimiterConfig struct {
    20  	RequestsPerInterval int
    21  	Interval            time.Duration
    22  }
    23  
    24  func initConfig(config ...RateLimiterConfig) *RateLimiterConfig {
    25  	var initConfig *RateLimiterConfig
    26  	if len(config) == 0 {
    27  		initConfig = &RateLimiterConfig{}
    28  	} else {
    29  		initConfig = &config[0]
    30  	}
    31  
    32  	if initConfig.RequestsPerInterval == 0 {
    33  		initConfig.RequestsPerInterval = 5
    34  	}
    35  	if initConfig.Interval == 0 {
    36  		initConfig.Interval = 1 * time.Second
    37  	}
    38  
    39  	return initConfig
    40  }
    41  
    42  func NewRateLimiter(rlConfig ...RateLimiterConfig) *RateLimiter {
    43  	config := initConfig(rlConfig...)
    44  
    45  	rl := &RateLimiter{
    46  		requestsPerInterval: config.RequestsPerInterval,
    47  		interval:            config.Interval,
    48  		accessCounter:       make(map[string]int),
    49  		mux:                 &sync.Mutex{},
    50  	}
    51  
    52  	middleware := func(next gsk.HandlerFunc) gsk.HandlerFunc {
    53  		return func(c *gsk.Context) {
    54  			clientIP := c.Request.RemoteAddr
    55  			rl.mux.Lock()
    56  			defer rl.mux.Unlock()
    57  
    58  			if cnt, ok := rl.accessCounter[clientIP]; ok {
    59  				if cnt >= rl.requestsPerInterval {
    60  					c.Status(http.StatusTooManyRequests).JSONResponse(gsk.Map{
    61  						"error": "Too many requests. Please try again later.",
    62  					})
    63  					return
    64  				}
    65  				rl.accessCounter[clientIP]++
    66  			} else {
    67  				rl.accessCounter[clientIP] = 1
    68  				go func(ip string) {
    69  					time.Sleep(rl.interval)
    70  					rl.mux.Lock()
    71  					defer rl.mux.Unlock()
    72  					delete(rl.accessCounter, ip)
    73  				}(clientIP)
    74  			}
    75  
    76  			next(c)
    77  		}
    78  	}
    79  
    80  	rl.Middleware = middleware
    81  
    82  	return rl
    83  
    84  }