gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter@v0.0.0-20230411193226-3247984d5abc/httplimit/middleware_test.go (about)

     1  package httplimit_test
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"strconv"
     8  	"testing"
     9  	"time"
    10  
    11  	"gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter/httplimit"
    12  	"gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter/memorystore"
    13  )
    14  
    15  func TestNewMiddleware(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	cases := []struct {
    19  		name     string
    20  		tokens   uint64
    21  		interval time.Duration
    22  	}{
    23  		{
    24  			name:     "millisecond",
    25  			tokens:   5,
    26  			interval: 500 * time.Millisecond,
    27  		},
    28  		{
    29  			name:     "second",
    30  			tokens:   3,
    31  			interval: time.Second,
    32  		},
    33  	}
    34  
    35  	for _, tc := range cases {
    36  		tc := tc
    37  
    38  		t.Run(tc.name, func(t *testing.T) {
    39  			t.Parallel()
    40  
    41  			store, err := memorystore.New(&memorystore.Config{
    42  				Tokens:   tc.tokens,
    43  				Interval: tc.interval,
    44  			})
    45  			if err != nil {
    46  				t.Fatal(err)
    47  			}
    48  
    49  			middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
    50  			if err != nil {
    51  				t.Fatal(err)
    52  			}
    53  
    54  			doWork := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    55  				w.WriteHeader(200)
    56  				fmt.Fprintf(w, "hello world")
    57  			})
    58  
    59  			server := httptest.NewServer(middleware.Handle(doWork))
    60  			defer server.Close()
    61  
    62  			client := server.Client()
    63  
    64  			for i := uint64(0); i < tc.tokens; i++ {
    65  				resp, err := client.Get(server.URL)
    66  				if err != nil {
    67  					t.Fatal(err)
    68  				}
    69  
    70  				limit, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitLimit), 10, 64)
    71  				if err != nil {
    72  					t.Fatal(err)
    73  				}
    74  				if got, want := limit, tc.tokens; got != want {
    75  					t.Errorf("limit: expected %d to be %d", got, want)
    76  				}
    77  
    78  				reset, err := time.Parse(time.RFC1123, resp.Header.Get(httplimit.HeaderRateLimitReset))
    79  				if err != nil {
    80  					t.Fatal(err)
    81  				}
    82  				if got, want := time.Until(reset), tc.interval; got > want {
    83  					t.Errorf("reset: expected %d to be less than %d", got, want)
    84  				}
    85  
    86  				remaining, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitRemaining), 10, 64)
    87  				if err != nil {
    88  					t.Fatal(err)
    89  				}
    90  				if got, want := remaining, tc.tokens-uint64(i)-1; got != want {
    91  					t.Errorf("remaining: expected %d to be %d", got, want)
    92  				}
    93  			}
    94  
    95  			// Should be limited
    96  			resp, err := client.Get(server.URL)
    97  			if err != nil {
    98  				t.Fatal(err)
    99  			}
   100  			if got, want := resp.StatusCode, http.StatusTooManyRequests; got != want {
   101  				t.Errorf("expected %d to be %d", got, want)
   102  			}
   103  
   104  			limit, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitLimit), 10, 64)
   105  			if err != nil {
   106  				t.Fatal(err)
   107  			}
   108  			if got, want := limit, tc.tokens; got != want {
   109  				t.Errorf("limit: expected %d to be %d", got, want)
   110  			}
   111  
   112  			reset, err := time.Parse(time.RFC1123, resp.Header.Get(httplimit.HeaderRateLimitReset))
   113  			if err != nil {
   114  				t.Fatal(err)
   115  			}
   116  			if got, want := time.Until(reset), tc.interval; got > want {
   117  				t.Errorf("reset: expected %d to be less than %d", got, want)
   118  			}
   119  
   120  			remaining, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitRemaining), 10, 64)
   121  			if err != nil {
   122  				t.Fatal(err)
   123  			}
   124  			if got, want := remaining, uint64(0); got != want {
   125  				t.Errorf("remaining: expected %d to be %d", got, want)
   126  			}
   127  		})
   128  	}
   129  }