github.com/adharshmk96/stk@v1.2.3/pkg/middleware/rate_limiter.go (about) 1 package middleware 2 3 import ( 4 "net/http" 5 "sync" 6 "time" 7 8 "github.com/adharshmk96/stk/gsk" 9 ) 10 11 type RateLimiter struct { 12 requestsPerInterval int 13 interval time.Duration 14 accessCounter map[string]int 15 mux *sync.Mutex 16 Middleware gsk.Middleware 17 } 18 19 type RateLimiterConfig struct { 20 RequestsPerInterval int 21 Interval time.Duration 22 } 23 24 func initConfig(config ...RateLimiterConfig) *RateLimiterConfig { 25 var initConfig *RateLimiterConfig 26 if len(config) == 0 { 27 initConfig = &RateLimiterConfig{} 28 } else { 29 initConfig = &config[0] 30 } 31 32 if initConfig.RequestsPerInterval == 0 { 33 initConfig.RequestsPerInterval = 5 34 } 35 if initConfig.Interval == 0 { 36 initConfig.Interval = 1 * time.Second 37 } 38 39 return initConfig 40 } 41 42 func NewRateLimiter(rlConfig ...RateLimiterConfig) *RateLimiter { 43 config := initConfig(rlConfig...) 44 45 rl := &RateLimiter{ 46 requestsPerInterval: config.RequestsPerInterval, 47 interval: config.Interval, 48 accessCounter: make(map[string]int), 49 mux: &sync.Mutex{}, 50 } 51 52 middleware := func(next gsk.HandlerFunc) gsk.HandlerFunc { 53 return func(c *gsk.Context) { 54 clientIP := c.Request.RemoteAddr 55 rl.mux.Lock() 56 defer rl.mux.Unlock() 57 58 if cnt, ok := rl.accessCounter[clientIP]; ok { 59 if cnt >= rl.requestsPerInterval { 60 c.Status(http.StatusTooManyRequests).JSONResponse(gsk.Map{ 61 "error": "Too many requests. Please try again later.", 62 }) 63 return 64 } 65 rl.accessCounter[clientIP]++ 66 } else { 67 rl.accessCounter[clientIP] = 1 68 go func(ip string) { 69 time.Sleep(rl.interval) 70 rl.mux.Lock() 71 defer rl.mux.Unlock() 72 delete(rl.accessCounter, ip) 73 }(clientIP) 74 } 75 76 next(c) 77 } 78 } 79 80 rl.Middleware = middleware 81 82 return rl 83 84 }