github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/ratelimit/ratelimit.go (about)

     1  package ratelimit
     2  
     3  import (
     4  	"net"
     5  	"net/http"
     6  	"sync"
     7  	"time"
     8  
     9  	log "github.com/gophish/gophish/logger"
    10  	"golang.org/x/time/rate"
    11  )
    12  
    13  // DefaultRequestsPerMinute is the number of requests to allow per minute.
    14  // Any requests over this interval will return a HTTP 429 error.
    15  const DefaultRequestsPerMinute = 5
    16  
    17  // DefaultCleanupInterval determines how frequently the cleanup routine
    18  // executes.
    19  const DefaultCleanupInterval = 1 * time.Minute
    20  
    21  // DefaultExpiry is the amount of time to track a bucket for a particular
    22  // visitor.
    23  const DefaultExpiry = 10 * time.Minute
    24  
    25  type bucket struct {
    26  	limiter  *rate.Limiter
    27  	lastSeen time.Time
    28  }
    29  
    30  // PostLimiter is a simple rate limiting middleware which only allows n POST
    31  // requests per minute.
    32  type PostLimiter struct {
    33  	visitors        map[string]*bucket
    34  	requestLimit    int
    35  	cleanupInterval time.Duration
    36  	expiry          time.Duration
    37  	sync.RWMutex
    38  }
    39  
    40  // PostLimiterOption is a functional option that allows callers to configure
    41  // the rate limiter.
    42  type PostLimiterOption func(*PostLimiter)
    43  
    44  // WithRequestsPerMinute sets the number of requests to allow per minute.
    45  func WithRequestsPerMinute(requestLimit int) PostLimiterOption {
    46  	return func(p *PostLimiter) {
    47  		p.requestLimit = requestLimit
    48  	}
    49  }
    50  
    51  // WithCleanupInterval sets the interval between cleaning up stale entries in
    52  // the rate limit client list
    53  func WithCleanupInterval(interval time.Duration) PostLimiterOption {
    54  	return func(p *PostLimiter) {
    55  		p.cleanupInterval = interval
    56  	}
    57  }
    58  
    59  // WithExpiry sets the amount of time to store client entries before they are
    60  // considered stale.
    61  func WithExpiry(expiry time.Duration) PostLimiterOption {
    62  	return func(p *PostLimiter) {
    63  		p.expiry = expiry
    64  	}
    65  }
    66  
    67  // NewPostLimiter returns a new instance of a PostLimiter
    68  func NewPostLimiter(opts ...PostLimiterOption) *PostLimiter {
    69  	limiter := &PostLimiter{
    70  		visitors:        make(map[string]*bucket),
    71  		requestLimit:    DefaultRequestsPerMinute,
    72  		cleanupInterval: DefaultCleanupInterval,
    73  		expiry:          DefaultExpiry,
    74  	}
    75  	for _, opt := range opts {
    76  		opt(limiter)
    77  	}
    78  	go limiter.pollCleanup()
    79  	return limiter
    80  }
    81  
    82  func (limiter *PostLimiter) pollCleanup() {
    83  	ticker := time.NewTicker(time.Duration(limiter.cleanupInterval) * time.Second)
    84  	for range ticker.C {
    85  		limiter.Cleanup()
    86  	}
    87  }
    88  
    89  // Cleanup removes any buckets that were last seen past the configured expiry.
    90  func (limiter *PostLimiter) Cleanup() {
    91  	limiter.Lock()
    92  	defer limiter.Unlock()
    93  	for ip, bucket := range limiter.visitors {
    94  		if time.Since(bucket.lastSeen) >= limiter.expiry {
    95  			delete(limiter.visitors, ip)
    96  		}
    97  	}
    98  }
    99  
   100  func (limiter *PostLimiter) addBucket(ip string) *bucket {
   101  	limiter.Lock()
   102  	defer limiter.Unlock()
   103  	limit := rate.NewLimiter(rate.Every(time.Minute/time.Duration(limiter.requestLimit)), limiter.requestLimit)
   104  	b := &bucket{
   105  		limiter: limit,
   106  	}
   107  	limiter.visitors[ip] = b
   108  	return b
   109  }
   110  
   111  func (limiter *PostLimiter) allow(ip string) bool {
   112  	// Check if we have a limiter already active for this clientIP
   113  	limiter.RLock()
   114  	bucket, exists := limiter.visitors[ip]
   115  	limiter.RUnlock()
   116  	if !exists {
   117  		bucket = limiter.addBucket(ip)
   118  	}
   119  	// Update the lastSeen for this bucket to assist with cleanup
   120  	limiter.Lock()
   121  	defer limiter.Unlock()
   122  	bucket.lastSeen = time.Now()
   123  	return bucket.limiter.Allow()
   124  }
   125  
   126  // Limit enforces the configured rate limit for POST requests.
   127  //
   128  // TODO: Change the return value to an http.Handler when we clean up the
   129  // way Gophish routing is done.
   130  func (limiter *PostLimiter) Limit(next http.Handler) http.HandlerFunc {
   131  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   132  		clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
   133  		if err != nil {
   134  			clientIP = r.RemoteAddr
   135  		}
   136  		if r.Method == http.MethodPost && !limiter.allow(clientIP) {
   137  			log.Error("")
   138  			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
   139  			return
   140  		}
   141  		next.ServeHTTP(w, r)
   142  	})
   143  }