github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/gin/middleware_test.go (about)

     1  package gin_test
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"strconv"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  
    11  	libgin "github.com/gin-gonic/gin"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/ulule/limiter/v3"
    15  	"github.com/ulule/limiter/v3/drivers/middleware/gin"
    16  	"github.com/ulule/limiter/v3/drivers/store/memory"
    17  )
    18  
    19  func TestHTTPMiddleware(t *testing.T) {
    20  	is := require.New(t)
    21  	libgin.SetMode(libgin.TestMode)
    22  
    23  	request, err := http.NewRequest("GET", "/", nil)
    24  	is.NoError(err)
    25  	is.NotNil(request)
    26  
    27  	store := memory.NewStore()
    28  	is.NotZero(store)
    29  
    30  	rate, err := limiter.NewRateFromFormatted("10-M")
    31  	is.NoError(err)
    32  	is.NotZero(rate)
    33  
    34  	middleware := gin.NewMiddleware(limiter.New(store, rate))
    35  	is.NotZero(middleware)
    36  
    37  	router := libgin.New()
    38  	router.Use(middleware)
    39  	router.GET("/", func(c *libgin.Context) {
    40  		c.String(http.StatusOK, "hello")
    41  	})
    42  
    43  	success := int64(10)
    44  	clients := int64(100)
    45  
    46  	//
    47  	// Sequential
    48  	//
    49  
    50  	for i := int64(1); i <= clients; i++ {
    51  
    52  		resp := httptest.NewRecorder()
    53  		router.ServeHTTP(resp, request)
    54  
    55  		if i <= success {
    56  			is.Equal(resp.Code, http.StatusOK)
    57  		} else {
    58  			is.Equal(resp.Code, http.StatusTooManyRequests)
    59  		}
    60  	}
    61  
    62  	//
    63  	// Concurrent
    64  	//
    65  
    66  	store = memory.NewStore()
    67  	is.NotZero(store)
    68  
    69  	middleware = gin.NewMiddleware(limiter.New(store, rate))
    70  	is.NotZero(middleware)
    71  
    72  	router = libgin.New()
    73  	router.Use(middleware)
    74  	router.GET("/", func(c *libgin.Context) {
    75  		c.String(http.StatusOK, "hello")
    76  	})
    77  
    78  	wg := &sync.WaitGroup{}
    79  	counter := int64(0)
    80  
    81  	for i := int64(1); i <= clients; i++ {
    82  		wg.Add(1)
    83  		go func() {
    84  
    85  			resp := httptest.NewRecorder()
    86  			router.ServeHTTP(resp, request)
    87  
    88  			if resp.Code == http.StatusOK {
    89  				atomic.AddInt64(&counter, 1)
    90  			}
    91  
    92  			wg.Done()
    93  		}()
    94  	}
    95  
    96  	wg.Wait()
    97  	is.Equal(success, atomic.LoadInt64(&counter))
    98  
    99  	//
   100  	// Custom KeyGetter
   101  	//
   102  
   103  	store = memory.NewStore()
   104  	is.NotZero(store)
   105  
   106  	counter = int64(0)
   107  	keyGetter := func(c *libgin.Context) string {
   108  		v := atomic.AddInt64(&counter, 1)
   109  		return strconv.FormatInt(v, 10)
   110  	}
   111  
   112  	middleware = gin.NewMiddleware(limiter.New(store, rate), gin.WithKeyGetter(keyGetter))
   113  	is.NotZero(middleware)
   114  
   115  	router = libgin.New()
   116  	router.Use(middleware)
   117  	router.GET("/", func(c *libgin.Context) {
   118  		c.String(http.StatusOK, "hello")
   119  	})
   120  
   121  	for i := int64(1); i <= clients; i++ {
   122  		resp := httptest.NewRecorder()
   123  		router.ServeHTTP(resp, request)
   124  		// We should always be ok as the key changes for each request
   125  		is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10))
   126  	}
   127  
   128  	//
   129  	// Test ExcludedKey
   130  	//
   131  	store = memory.NewStore()
   132  	is.NotZero(store)
   133  	counter = int64(0)
   134  	excludedKeyFn := func(key string) bool {
   135  		return key == "1"
   136  	}
   137  	middleware = gin.NewMiddleware(limiter.New(store, rate),
   138  		gin.WithKeyGetter(func(c *libgin.Context) string {
   139  			v := atomic.AddInt64(&counter, 1)
   140  			return strconv.FormatInt(v%2, 10)
   141  		}),
   142  		gin.WithExcludedKey(excludedKeyFn),
   143  	)
   144  	is.NotZero(middleware)
   145  
   146  	router = libgin.New()
   147  	router.Use(middleware)
   148  	router.GET("/", func(c *libgin.Context) {
   149  		c.String(http.StatusOK, "hello")
   150  	})
   151  	success = 20
   152  	for i := int64(1); i < clients; i++ {
   153  		resp := httptest.NewRecorder()
   154  		router.ServeHTTP(resp, request)
   155  		if i <= success || i%2 == 1 {
   156  			is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10))
   157  		} else {
   158  			is.Equal(resp.Code, http.StatusTooManyRequests)
   159  		}
   160  	}
   161  }