github.com/mvg-fi/go-limiter@v0.1.1/httplimit/middleware_test.go (about) 1 package httplimit_test 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "strconv" 8 "testing" 9 "time" 10 11 "github.com/mvg-fi/go-limiter/httplimit" 12 "github.com/mvg-fi/go-limiter/memorystore" 13 ) 14 15 func TestNewMiddleware(t *testing.T) { 16 t.Parallel() 17 18 cases := []struct { 19 name string 20 tokens uint64 21 interval time.Duration 22 }{ 23 { 24 name: "millisecond", 25 tokens: 5, 26 interval: 500 * time.Millisecond, 27 }, 28 { 29 name: "second", 30 tokens: 3, 31 interval: time.Second, 32 }, 33 } 34 35 for _, tc := range cases { 36 tc := tc 37 38 t.Run(tc.name, func(t *testing.T) { 39 t.Parallel() 40 41 store, err := memorystore.New(&memorystore.Config{ 42 Tokens: tc.tokens, 43 Interval: tc.interval, 44 }) 45 if err != nil { 46 t.Fatal(err) 47 } 48 49 middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc()) 50 if err != nil { 51 t.Fatal(err) 52 } 53 54 doWork := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 55 w.WriteHeader(200) 56 fmt.Fprintf(w, "hello world") 57 }) 58 59 server := httptest.NewServer(middleware.Handle(doWork)) 60 defer server.Close() 61 62 client := server.Client() 63 64 for i := uint64(0); i < tc.tokens; i++ { 65 resp, err := client.Get(server.URL) 66 if err != nil { 67 t.Fatal(err) 68 } 69 70 limit, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitLimit), 10, 64) 71 if err != nil { 72 t.Fatal(err) 73 } 74 if got, want := limit, tc.tokens; got != want { 75 t.Errorf("limit: expected %d to be %d", got, want) 76 } 77 78 reset, err := time.Parse(time.RFC1123, resp.Header.Get(httplimit.HeaderRateLimitReset)) 79 if err != nil { 80 t.Fatal(err) 81 } 82 if got, want := time.Until(reset), tc.interval; got > want { 83 t.Errorf("reset: expected %d to be less than %d", got, want) 84 } 85 86 remaining, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitRemaining), 10, 64) 87 if err != nil { 88 t.Fatal(err) 89 } 90 if got, want := remaining, tc.tokens-uint64(i)-1; got != want { 91 t.Errorf("remaining: expected %d to be %d", got, want) 92 } 93 } 94 95 // Should be limited 96 resp, err := client.Get(server.URL) 97 if err != nil { 98 t.Fatal(err) 99 } 100 if got, want := resp.StatusCode, http.StatusTooManyRequests; got != want { 101 t.Errorf("expected %d to be %d", got, want) 102 } 103 104 limit, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitLimit), 10, 64) 105 if err != nil { 106 t.Fatal(err) 107 } 108 if got, want := limit, tc.tokens; got != want { 109 t.Errorf("limit: expected %d to be %d", got, want) 110 } 111 112 reset, err := time.Parse(time.RFC1123, resp.Header.Get(httplimit.HeaderRateLimitReset)) 113 if err != nil { 114 t.Fatal(err) 115 } 116 if got, want := time.Until(reset), tc.interval; got > want { 117 t.Errorf("reset: expected %d to be less than %d", got, want) 118 } 119 120 remaining, err := strconv.ParseUint(resp.Header.Get(httplimit.HeaderRateLimitRemaining), 10, 64) 121 if err != nil { 122 t.Fatal(err) 123 } 124 if got, want := remaining, uint64(0); got != want { 125 t.Errorf("remaining: expected %d to be %d", got, want) 126 } 127 }) 128 } 129 }