github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/ratelimit/ratelimit_test.go (about) 1 package ratelimit 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 ) 8 9 var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 10 w.Write([]byte("ok")) 11 }) 12 13 func reachLimit(t *testing.T, handler http.Handler, limit int) { 14 // Make `expected` requests and ensure that each return a successful 15 // response. 16 r := httptest.NewRequest(http.MethodPost, "/", nil) 17 r.RemoteAddr = "127.0.0.1:" 18 for i := 0; i < limit; i++ { 19 w := httptest.NewRecorder() 20 handler.ServeHTTP(w, r) 21 if w.Code != http.StatusOK { 22 t.Fatalf("no 200 on req %d got %d", i, w.Code) 23 } 24 } 25 // Then, makes another request to ensure it returns the 429 26 // status. 27 w := httptest.NewRecorder() 28 handler.ServeHTTP(w, r) 29 if w.Code != http.StatusTooManyRequests { 30 t.Fatalf("no 429") 31 } 32 } 33 34 func TestRateLimitEnforcement(t *testing.T) { 35 expectedLimit := 3 36 limiter := NewPostLimiter(WithRequestsPerMinute(expectedLimit)) 37 handler := limiter.Limit(successHandler) 38 reachLimit(t, handler, expectedLimit) 39 } 40 41 func TestRateLimitCleanup(t *testing.T) { 42 expectedLimit := 3 43 limiter := NewPostLimiter(WithRequestsPerMinute(expectedLimit)) 44 handler := limiter.Limit(successHandler) 45 reachLimit(t, handler, expectedLimit) 46 47 // Set the timeout to be 48 bucket, exists := limiter.visitors["127.0.0.1"] 49 if !exists { 50 t.Fatalf("doesn't exist for some reason") 51 } 52 bucket.lastSeen = bucket.lastSeen.Add(-limiter.expiry) 53 limiter.Cleanup() 54 _, exists = limiter.visitors["127.0.0.1"] 55 if exists { 56 t.Fatalf("exists for some reason") 57 } 58 reachLimit(t, handler, expectedLimit) 59 }