github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/fasthttp/middleware_test.go (about) 1 package fasthttp_test 2 3 import ( 4 "net" 5 "strconv" 6 "sync" 7 "sync/atomic" 8 "testing" 9 10 "github.com/stretchr/testify/require" 11 libfasthttp "github.com/valyala/fasthttp" 12 "github.com/valyala/fasthttp/fasthttputil" 13 14 "github.com/ulule/limiter/v3" 15 "github.com/ulule/limiter/v3/drivers/middleware/fasthttp" 16 "github.com/ulule/limiter/v3/drivers/store/memory" 17 ) 18 19 // nolint: gocyclo 20 func TestFasthttpMiddleware(t *testing.T) { 21 is := require.New(t) 22 23 store := memory.NewStore() 24 is.NotZero(store) 25 26 rate, err := limiter.NewRateFromFormatted("10-M") 27 is.NoError(err) 28 is.NotZero(rate) 29 30 middleware := fasthttp.NewMiddleware(limiter.New(store, rate)) 31 32 requestHandler := func(ctx *libfasthttp.RequestCtx) { 33 switch string(ctx.Path()) { 34 case "/": 35 ctx.SetStatusCode(libfasthttp.StatusOK) 36 ctx.SetBodyString("hello") 37 } 38 } 39 40 success := int64(10) 41 clients := int64(100) 42 43 // 44 // Sequential 45 // 46 47 for i := int64(1); i <= clients; i++ { 48 resp := libfasthttp.AcquireResponse() 49 req := libfasthttp.AcquireRequest() 50 req.Header.SetHost("localhost:8081") 51 req.Header.SetRequestURI("/") 52 err := serve(middleware.Handle(requestHandler), req, resp) 53 is.NoError(err) 54 55 if i <= success { 56 is.Equal(resp.StatusCode(), libfasthttp.StatusOK) 57 } else { 58 is.Equal(resp.StatusCode(), libfasthttp.StatusTooManyRequests) 59 } 60 } 61 62 // 63 // Concurrent 64 // 65 66 store = memory.NewStore() 67 is.NotZero(store) 68 69 middleware = fasthttp.NewMiddleware(limiter.New(store, rate)) 70 71 requestHandler = func(ctx *libfasthttp.RequestCtx) { 72 switch string(ctx.Path()) { 73 case "/": 74 ctx.SetStatusCode(libfasthttp.StatusOK) 75 ctx.SetBodyString("hello") 76 } 77 } 78 79 wg := &sync.WaitGroup{} 80 counter := int64(0) 81 82 for i := int64(1); i <= clients; i++ { 83 wg.Add(1) 84 85 go func() { 86 resp := libfasthttp.AcquireResponse() 87 req := libfasthttp.AcquireRequest() 88 req.Header.SetHost("localhost:8081") 89 req.Header.SetRequestURI("/") 90 err := serve(middleware.Handle(requestHandler), req, resp) 91 is.NoError(err) 92 93 if resp.StatusCode() == libfasthttp.StatusOK { 94 atomic.AddInt64(&counter, 1) 95 } 96 97 wg.Done() 98 }() 99 } 100 101 wg.Wait() 102 is.Equal(success, atomic.LoadInt64(&counter)) 103 104 // 105 // Custom KeyGetter 106 // 107 108 store = memory.NewStore() 109 is.NotZero(store) 110 111 counter = int64(0) 112 keyGetter := func(c *libfasthttp.RequestCtx) string { 113 v := atomic.AddInt64(&counter, 1) 114 return strconv.FormatInt(v, 10) 115 } 116 117 middleware = fasthttp.NewMiddleware(limiter.New(store, rate), fasthttp.WithKeyGetter(keyGetter)) 118 is.NotZero(middleware) 119 120 requestHandler = func(ctx *libfasthttp.RequestCtx) { 121 switch string(ctx.Path()) { 122 case "/": 123 ctx.SetStatusCode(libfasthttp.StatusOK) 124 ctx.SetBodyString("hello") 125 } 126 } 127 128 for i := int64(1); i <= clients; i++ { 129 resp := libfasthttp.AcquireResponse() 130 req := libfasthttp.AcquireRequest() 131 req.Header.SetHost("localhost:8081") 132 req.Header.SetRequestURI("/") 133 err := serve(middleware.Handle(requestHandler), req, resp) 134 is.NoError(err) 135 is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) 136 } 137 138 // 139 // Test ExcludedKey 140 // 141 142 store = memory.NewStore() 143 is.NotZero(store) 144 145 counter = int64(0) 146 keyGetterHandler := func(c *libfasthttp.RequestCtx) string { 147 v := atomic.AddInt64(&counter, 1) 148 return strconv.FormatInt(v%2, 10) 149 } 150 excludedKeyHandler := func(key string) bool { 151 return key == "1" 152 } 153 154 middleware = fasthttp.NewMiddleware(limiter.New(store, rate), 155 fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler)) 156 is.NotZero(middleware) 157 158 requestHandler = func(ctx *libfasthttp.RequestCtx) { 159 switch string(ctx.Path()) { 160 case "/": 161 ctx.SetStatusCode(libfasthttp.StatusOK) 162 ctx.SetBodyString("hello") 163 } 164 } 165 166 success = 20 167 for i := int64(1); i <= clients; i++ { 168 resp := libfasthttp.AcquireResponse() 169 req := libfasthttp.AcquireRequest() 170 req.Header.SetHost("localhost:8081") 171 req.Header.SetRequestURI("/") 172 err := serve(middleware.Handle(requestHandler), req, resp) 173 is.NoError(err) 174 if i <= success || i%2 == 1 { 175 is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) 176 } else { 177 is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10)) 178 } 179 } 180 } 181 182 func serve(handler libfasthttp.RequestHandler, req *libfasthttp.Request, res *libfasthttp.Response) error { 183 ln := fasthttputil.NewInmemoryListener() 184 defer func() { 185 err := ln.Close() 186 if err != nil { 187 panic(err) 188 } 189 }() 190 191 go func() { 192 err := libfasthttp.Serve(ln, handler) 193 if err != nil { 194 panic(err) 195 } 196 }() 197 198 client := libfasthttp.Client{ 199 Dial: func(addr string) (net.Conn, error) { 200 return ln.Dial() 201 }, 202 } 203 204 return client.Do(req, res) 205 }