github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/middleware/stdlib/middleware_test.go (about)

     1  package stdlib_test
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/ulule/limiter/v3"
    13  	"github.com/ulule/limiter/v3/drivers/middleware/stdlib"
    14  	"github.com/ulule/limiter/v3/drivers/store/memory"
    15  )
    16  
    17  func TestHTTPMiddleware(t *testing.T) {
    18  	is := require.New(t)
    19  
    20  	request, err := http.NewRequest("GET", "/", nil)
    21  	is.NoError(err)
    22  	is.NotNil(request)
    23  
    24  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    25  		_, thr := w.Write([]byte("hello"))
    26  		if thr != nil {
    27  			panic(thr)
    28  		}
    29  	})
    30  
    31  	store := memory.NewStore()
    32  	is.NotZero(store)
    33  
    34  	rate, err := limiter.NewRateFromFormatted("10-M")
    35  	is.NoError(err)
    36  	is.NotZero(rate)
    37  
    38  	middleware := stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler)
    39  	is.NotZero(middleware)
    40  
    41  	success := int64(10)
    42  	clients := int64(100)
    43  
    44  	//
    45  	// Sequential
    46  	//
    47  
    48  	for i := int64(1); i <= clients; i++ {
    49  
    50  		resp := httptest.NewRecorder()
    51  		middleware.ServeHTTP(resp, request)
    52  
    53  		if i <= success {
    54  			is.Equal(resp.Code, http.StatusOK)
    55  		} else {
    56  			is.Equal(resp.Code, http.StatusTooManyRequests)
    57  		}
    58  	}
    59  
    60  	//
    61  	// Concurrent
    62  	//
    63  
    64  	store = memory.NewStore()
    65  	is.NotZero(store)
    66  
    67  	middleware = stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler)
    68  	is.NotZero(middleware)
    69  
    70  	wg := &sync.WaitGroup{}
    71  	counter := int64(0)
    72  
    73  	for i := int64(1); i <= clients; i++ {
    74  		wg.Add(1)
    75  		go func() {
    76  
    77  			resp := httptest.NewRecorder()
    78  			middleware.ServeHTTP(resp, request)
    79  
    80  			if resp.Code == http.StatusOK {
    81  				atomic.AddInt64(&counter, 1)
    82  			}
    83  
    84  			wg.Done()
    85  		}()
    86  	}
    87  
    88  	wg.Wait()
    89  	is.Equal(success, atomic.LoadInt64(&counter))
    90  
    91  }