github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/fasthttp/middleware_test.go (about)

     1  package fasthttp_test
     2  
     3  import (
     4  	"net"
     5  	"strconv"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  	libfasthttp "github.com/valyala/fasthttp"
    12  	"github.com/valyala/fasthttp/fasthttputil"
    13  
    14  	"github.com/ulule/limiter/v3"
    15  	"github.com/ulule/limiter/v3/drivers/middleware/fasthttp"
    16  	"github.com/ulule/limiter/v3/drivers/store/memory"
    17  )
    18  
    19  // nolint: gocyclo
    20  func TestFasthttpMiddleware(t *testing.T) {
    21  	is := require.New(t)
    22  
    23  	store := memory.NewStore()
    24  	is.NotZero(store)
    25  
    26  	rate, err := limiter.NewRateFromFormatted("10-M")
    27  	is.NoError(err)
    28  	is.NotZero(rate)
    29  
    30  	middleware := fasthttp.NewMiddleware(limiter.New(store, rate))
    31  
    32  	requestHandler := func(ctx *libfasthttp.RequestCtx) {
    33  		switch string(ctx.Path()) {
    34  		case "/":
    35  			ctx.SetStatusCode(libfasthttp.StatusOK)
    36  			ctx.SetBodyString("hello")
    37  		}
    38  	}
    39  
    40  	success := int64(10)
    41  	clients := int64(100)
    42  
    43  	//
    44  	// Sequential
    45  	//
    46  
    47  	for i := int64(1); i <= clients; i++ {
    48  		resp := libfasthttp.AcquireResponse()
    49  		req := libfasthttp.AcquireRequest()
    50  		req.Header.SetHost("localhost:8081")
    51  		req.Header.SetRequestURI("/")
    52  		err := serve(middleware.Handle(requestHandler), req, resp)
    53  		is.NoError(err)
    54  
    55  		if i <= success {
    56  			is.Equal(resp.StatusCode(), libfasthttp.StatusOK)
    57  		} else {
    58  			is.Equal(resp.StatusCode(), libfasthttp.StatusTooManyRequests)
    59  		}
    60  	}
    61  
    62  	//
    63  	// Concurrent
    64  	//
    65  
    66  	store = memory.NewStore()
    67  	is.NotZero(store)
    68  
    69  	middleware = fasthttp.NewMiddleware(limiter.New(store, rate))
    70  
    71  	requestHandler = func(ctx *libfasthttp.RequestCtx) {
    72  		switch string(ctx.Path()) {
    73  		case "/":
    74  			ctx.SetStatusCode(libfasthttp.StatusOK)
    75  			ctx.SetBodyString("hello")
    76  		}
    77  	}
    78  
    79  	wg := &sync.WaitGroup{}
    80  	counter := int64(0)
    81  
    82  	for i := int64(1); i <= clients; i++ {
    83  		wg.Add(1)
    84  
    85  		go func() {
    86  			resp := libfasthttp.AcquireResponse()
    87  			req := libfasthttp.AcquireRequest()
    88  			req.Header.SetHost("localhost:8081")
    89  			req.Header.SetRequestURI("/")
    90  			err := serve(middleware.Handle(requestHandler), req, resp)
    91  			is.NoError(err)
    92  
    93  			if resp.StatusCode() == libfasthttp.StatusOK {
    94  				atomic.AddInt64(&counter, 1)
    95  			}
    96  
    97  			wg.Done()
    98  		}()
    99  	}
   100  
   101  	wg.Wait()
   102  	is.Equal(success, atomic.LoadInt64(&counter))
   103  
   104  	//
   105  	// Custom KeyGetter
   106  	//
   107  
   108  	store = memory.NewStore()
   109  	is.NotZero(store)
   110  
   111  	counter = int64(0)
   112  	keyGetter := func(c *libfasthttp.RequestCtx) string {
   113  		v := atomic.AddInt64(&counter, 1)
   114  		return strconv.FormatInt(v, 10)
   115  	}
   116  
   117  	middleware = fasthttp.NewMiddleware(limiter.New(store, rate), fasthttp.WithKeyGetter(keyGetter))
   118  	is.NotZero(middleware)
   119  
   120  	requestHandler = func(ctx *libfasthttp.RequestCtx) {
   121  		switch string(ctx.Path()) {
   122  		case "/":
   123  			ctx.SetStatusCode(libfasthttp.StatusOK)
   124  			ctx.SetBodyString("hello")
   125  		}
   126  	}
   127  
   128  	for i := int64(1); i <= clients; i++ {
   129  		resp := libfasthttp.AcquireResponse()
   130  		req := libfasthttp.AcquireRequest()
   131  		req.Header.SetHost("localhost:8081")
   132  		req.Header.SetRequestURI("/")
   133  		err := serve(middleware.Handle(requestHandler), req, resp)
   134  		is.NoError(err)
   135  		is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10))
   136  	}
   137  
   138  	//
   139  	// Test ExcludedKey
   140  	//
   141  
   142  	store = memory.NewStore()
   143  	is.NotZero(store)
   144  
   145  	counter = int64(0)
   146  	keyGetterHandler := func(c *libfasthttp.RequestCtx) string {
   147  		v := atomic.AddInt64(&counter, 1)
   148  		return strconv.FormatInt(v%2, 10)
   149  	}
   150  	excludedKeyHandler := func(key string) bool {
   151  		return key == "1"
   152  	}
   153  
   154  	middleware = fasthttp.NewMiddleware(limiter.New(store, rate),
   155  		fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler))
   156  	is.NotZero(middleware)
   157  
   158  	requestHandler = func(ctx *libfasthttp.RequestCtx) {
   159  		switch string(ctx.Path()) {
   160  		case "/":
   161  			ctx.SetStatusCode(libfasthttp.StatusOK)
   162  			ctx.SetBodyString("hello")
   163  		}
   164  	}
   165  
   166  	success = 20
   167  	for i := int64(1); i <= clients; i++ {
   168  		resp := libfasthttp.AcquireResponse()
   169  		req := libfasthttp.AcquireRequest()
   170  		req.Header.SetHost("localhost:8081")
   171  		req.Header.SetRequestURI("/")
   172  		err := serve(middleware.Handle(requestHandler), req, resp)
   173  		is.NoError(err)
   174  		if i <= success || i%2 == 1 {
   175  			is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10))
   176  		} else {
   177  			is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10))
   178  		}
   179  	}
   180  }
   181  
   182  func serve(handler libfasthttp.RequestHandler, req *libfasthttp.Request, res *libfasthttp.Response) error {
   183  	ln := fasthttputil.NewInmemoryListener()
   184  	defer func() {
   185  		err := ln.Close()
   186  		if err != nil {
   187  			panic(err)
   188  		}
   189  	}()
   190  
   191  	go func() {
   192  		err := libfasthttp.Serve(ln, handler)
   193  		if err != nil {
   194  			panic(err)
   195  		}
   196  	}()
   197  
   198  	client := libfasthttp.Client{
   199  		Dial: func(addr string) (net.Conn, error) {
   200  			return ln.Dial()
   201  		},
   202  	}
   203  
   204  	return client.Do(req, res)
   205  }