code.vegaprotocol.io/vega@v0.79.0/datanode/ratelimit/ratelimit_test.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package ratelimit 17 18 import ( 19 "net/http" 20 "net/http/httptest" 21 "sync" 22 "testing" 23 "time" 24 25 "code.vegaprotocol.io/vega/logging" 26 27 "github.com/stretchr/testify/assert" 28 ) 29 30 func TestRateLimit_HTTPMiddleware(t *testing.T) { 31 mu := sync.Mutex{} 32 count := 0 33 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 mu.Lock() 35 defer mu.Unlock() 36 count++ 37 }) 38 39 req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", nil) 40 41 cfg := NewDefaultConfig() 42 const burstSize = 20 43 cfg.Burst = burstSize 44 45 r := NewFromConfig(&cfg, logging.NewTestLogger()) 46 47 limiter := r.HTTPMiddleware(handler) 48 for i := 0; i < cfg.Burst; i++ { 49 res := httptest.NewRecorder() 50 limiter.ServeHTTP(res, req) 51 assert.Equal(t, http.StatusOK, res.Code) 52 assert.Equal(t, i+1, count) 53 } 54 55 for i := 0; i < cfg.Burst+1; i++ { 56 res := httptest.NewRecorder() 57 limiter.ServeHTTP(res, req) 58 assert.Equal(t, http.StatusTooManyRequests, res.Code) 59 assert.Equal(t, burstSize, count) 60 } 61 62 // We should have been banned after this so wait a second, then request again, 63 // the ban time remaining should not be empty. 64 time.Sleep(time.Second) 65 66 res := httptest.NewRecorder() 67 limiter.ServeHTTP(res, req) 68 assert.Equal(t, http.StatusForbidden, res.Code) 69 expiry := res.Header().Get("RateLimit-Retry-After") 70 assert.NotEmpty(t, expiry) 71 }