github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/ratelimit/ratelimit.go (about) 1 package ratelimit 2 3 import ( 4 "net" 5 "net/http" 6 "sync" 7 "time" 8 9 log "github.com/gophish/gophish/logger" 10 "golang.org/x/time/rate" 11 ) 12 13 // DefaultRequestsPerMinute is the number of requests to allow per minute. 14 // Any requests over this interval will return a HTTP 429 error. 15 const DefaultRequestsPerMinute = 5 16 17 // DefaultCleanupInterval determines how frequently the cleanup routine 18 // executes. 19 const DefaultCleanupInterval = 1 * time.Minute 20 21 // DefaultExpiry is the amount of time to track a bucket for a particular 22 // visitor. 23 const DefaultExpiry = 10 * time.Minute 24 25 type bucket struct { 26 limiter *rate.Limiter 27 lastSeen time.Time 28 } 29 30 // PostLimiter is a simple rate limiting middleware which only allows n POST 31 // requests per minute. 32 type PostLimiter struct { 33 visitors map[string]*bucket 34 requestLimit int 35 cleanupInterval time.Duration 36 expiry time.Duration 37 sync.RWMutex 38 } 39 40 // PostLimiterOption is a functional option that allows callers to configure 41 // the rate limiter. 42 type PostLimiterOption func(*PostLimiter) 43 44 // WithRequestsPerMinute sets the number of requests to allow per minute. 45 func WithRequestsPerMinute(requestLimit int) PostLimiterOption { 46 return func(p *PostLimiter) { 47 p.requestLimit = requestLimit 48 } 49 } 50 51 // WithCleanupInterval sets the interval between cleaning up stale entries in 52 // the rate limit client list 53 func WithCleanupInterval(interval time.Duration) PostLimiterOption { 54 return func(p *PostLimiter) { 55 p.cleanupInterval = interval 56 } 57 } 58 59 // WithExpiry sets the amount of time to store client entries before they are 60 // considered stale. 61 func WithExpiry(expiry time.Duration) PostLimiterOption { 62 return func(p *PostLimiter) { 63 p.expiry = expiry 64 } 65 } 66 67 // NewPostLimiter returns a new instance of a PostLimiter 68 func NewPostLimiter(opts ...PostLimiterOption) *PostLimiter { 69 limiter := &PostLimiter{ 70 visitors: make(map[string]*bucket), 71 requestLimit: DefaultRequestsPerMinute, 72 cleanupInterval: DefaultCleanupInterval, 73 expiry: DefaultExpiry, 74 } 75 for _, opt := range opts { 76 opt(limiter) 77 } 78 go limiter.pollCleanup() 79 return limiter 80 } 81 82 func (limiter *PostLimiter) pollCleanup() { 83 ticker := time.NewTicker(time.Duration(limiter.cleanupInterval) * time.Second) 84 for range ticker.C { 85 limiter.Cleanup() 86 } 87 } 88 89 // Cleanup removes any buckets that were last seen past the configured expiry. 90 func (limiter *PostLimiter) Cleanup() { 91 limiter.Lock() 92 defer limiter.Unlock() 93 for ip, bucket := range limiter.visitors { 94 if time.Since(bucket.lastSeen) >= limiter.expiry { 95 delete(limiter.visitors, ip) 96 } 97 } 98 } 99 100 func (limiter *PostLimiter) addBucket(ip string) *bucket { 101 limiter.Lock() 102 defer limiter.Unlock() 103 limit := rate.NewLimiter(rate.Every(time.Minute/time.Duration(limiter.requestLimit)), limiter.requestLimit) 104 b := &bucket{ 105 limiter: limit, 106 } 107 limiter.visitors[ip] = b 108 return b 109 } 110 111 func (limiter *PostLimiter) allow(ip string) bool { 112 // Check if we have a limiter already active for this clientIP 113 limiter.RLock() 114 bucket, exists := limiter.visitors[ip] 115 limiter.RUnlock() 116 if !exists { 117 bucket = limiter.addBucket(ip) 118 } 119 // Update the lastSeen for this bucket to assist with cleanup 120 limiter.Lock() 121 defer limiter.Unlock() 122 bucket.lastSeen = time.Now() 123 return bucket.limiter.Allow() 124 } 125 126 // Limit enforces the configured rate limit for POST requests. 127 // 128 // TODO: Change the return value to an http.Handler when we clean up the 129 // way Gophish routing is done. 130 func (limiter *PostLimiter) Limit(next http.Handler) http.HandlerFunc { 131 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 132 clientIP, _, err := net.SplitHostPort(r.RemoteAddr) 133 if err != nil { 134 clientIP = r.RemoteAddr 135 } 136 if r.Method == http.MethodPost && !limiter.allow(clientIP) { 137 log.Error("") 138 http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) 139 return 140 } 141 next.ServeHTTP(w, r) 142 }) 143 }