github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/gin/middleware_test.go (about) 1 package gin_test 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "strconv" 7 "sync" 8 "sync/atomic" 9 "testing" 10 11 libgin "github.com/gin-gonic/gin" 12 "github.com/stretchr/testify/require" 13 14 "github.com/ulule/limiter/v3" 15 "github.com/ulule/limiter/v3/drivers/middleware/gin" 16 "github.com/ulule/limiter/v3/drivers/store/memory" 17 ) 18 19 func TestHTTPMiddleware(t *testing.T) { 20 is := require.New(t) 21 libgin.SetMode(libgin.TestMode) 22 23 request, err := http.NewRequest("GET", "/", nil) 24 is.NoError(err) 25 is.NotNil(request) 26 27 store := memory.NewStore() 28 is.NotZero(store) 29 30 rate, err := limiter.NewRateFromFormatted("10-M") 31 is.NoError(err) 32 is.NotZero(rate) 33 34 middleware := gin.NewMiddleware(limiter.New(store, rate)) 35 is.NotZero(middleware) 36 37 router := libgin.New() 38 router.Use(middleware) 39 router.GET("/", func(c *libgin.Context) { 40 c.String(http.StatusOK, "hello") 41 }) 42 43 success := int64(10) 44 clients := int64(100) 45 46 // 47 // Sequential 48 // 49 50 for i := int64(1); i <= clients; i++ { 51 52 resp := httptest.NewRecorder() 53 router.ServeHTTP(resp, request) 54 55 if i <= success { 56 is.Equal(resp.Code, http.StatusOK) 57 } else { 58 is.Equal(resp.Code, http.StatusTooManyRequests) 59 } 60 } 61 62 // 63 // Concurrent 64 // 65 66 store = memory.NewStore() 67 is.NotZero(store) 68 69 middleware = gin.NewMiddleware(limiter.New(store, rate)) 70 is.NotZero(middleware) 71 72 router = libgin.New() 73 router.Use(middleware) 74 router.GET("/", func(c *libgin.Context) { 75 c.String(http.StatusOK, "hello") 76 }) 77 78 wg := &sync.WaitGroup{} 79 counter := int64(0) 80 81 for i := int64(1); i <= clients; i++ { 82 wg.Add(1) 83 go func() { 84 85 resp := httptest.NewRecorder() 86 router.ServeHTTP(resp, request) 87 88 if resp.Code == http.StatusOK { 89 atomic.AddInt64(&counter, 1) 90 } 91 92 wg.Done() 93 }() 94 } 95 96 wg.Wait() 97 is.Equal(success, atomic.LoadInt64(&counter)) 98 99 // 100 // Custom KeyGetter 101 // 102 103 store = memory.NewStore() 104 is.NotZero(store) 105 106 counter = int64(0) 107 keyGetter := func(c *libgin.Context) string { 108 v := atomic.AddInt64(&counter, 1) 109 return strconv.FormatInt(v, 10) 110 } 111 112 middleware = gin.NewMiddleware(limiter.New(store, rate), gin.WithKeyGetter(keyGetter)) 113 is.NotZero(middleware) 114 115 router = libgin.New() 116 router.Use(middleware) 117 router.GET("/", func(c *libgin.Context) { 118 c.String(http.StatusOK, "hello") 119 }) 120 121 for i := int64(1); i <= clients; i++ { 122 resp := httptest.NewRecorder() 123 router.ServeHTTP(resp, request) 124 // We should always be ok as the key changes for each request 125 is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) 126 } 127 128 // 129 // Test ExcludedKey 130 // 131 store = memory.NewStore() 132 is.NotZero(store) 133 counter = int64(0) 134 excludedKeyFn := func(key string) bool { 135 return key == "1" 136 } 137 middleware = gin.NewMiddleware(limiter.New(store, rate), 138 gin.WithKeyGetter(func(c *libgin.Context) string { 139 v := atomic.AddInt64(&counter, 1) 140 return strconv.FormatInt(v%2, 10) 141 }), 142 gin.WithExcludedKey(excludedKeyFn), 143 ) 144 is.NotZero(middleware) 145 146 router = libgin.New() 147 router.Use(middleware) 148 router.GET("/", func(c *libgin.Context) { 149 c.String(http.StatusOK, "hello") 150 }) 151 success = 20 152 for i := int64(1); i < clients; i++ { 153 resp := httptest.NewRecorder() 154 router.ServeHTTP(resp, request) 155 if i <= success || i%2 == 1 { 156 is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) 157 } else { 158 is.Equal(resp.Code, http.StatusTooManyRequests) 159 } 160 } 161 }