github.com/goravel/framework@v1.13.9/http/middleware/throttle_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	nethttp "net/http"
     7  	"strconv"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/mock"
    14  
    15  	cachemocks "github.com/goravel/framework/contracts/cache/mocks"
    16  	configmocks "github.com/goravel/framework/contracts/config/mocks"
    17  	"github.com/goravel/framework/contracts/filesystem"
    18  	contractshttp "github.com/goravel/framework/contracts/http"
    19  	httpmocks "github.com/goravel/framework/contracts/http/mocks"
    20  	"github.com/goravel/framework/contracts/validation"
    21  	"github.com/goravel/framework/http"
    22  	"github.com/goravel/framework/http/limit"
    23  	"github.com/goravel/framework/support/carbon"
    24  )
    25  
    26  func TestThrottle(t *testing.T) {
    27  	var (
    28  		ctx                   *TestContext
    29  		mockCache             *cachemocks.Cache
    30  		mockConfig            *configmocks.Config
    31  		mockRateLimiterFacade *httpmocks.RateLimiter
    32  	)
    33  
    34  	now := carbon.Now()
    35  	carbon.SetTestNow(now)
    36  
    37  	tests := []struct {
    38  		name   string
    39  		setup  func()
    40  		assert func()
    41  	}{
    42  		{
    43  			name: "empty limiter",
    44  			setup: func() {
    45  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
    46  					return []contractshttp.Limit{}
    47  				}).Once()
    48  
    49  				assert.NotPanics(t, func() {
    50  					Throttle("test")(ctx)
    51  				})
    52  			},
    53  			assert: func() {
    54  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Reset"])
    55  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["Retry-After"])
    56  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Limit"])
    57  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Remaining"])
    58  			},
    59  		},
    60  		{
    61  			name: "not http limit",
    62  			setup: func() {
    63  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
    64  					return []contractshttp.Limit{
    65  						&TestLimit{},
    66  					}
    67  				}).Once()
    68  
    69  				assert.NotPanics(t, func() {
    70  					Throttle("test")(ctx)
    71  				})
    72  			},
    73  			assert: func() {
    74  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Reset"])
    75  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["Retry-After"])
    76  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Limit"])
    77  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Remaining"])
    78  			},
    79  		},
    80  		{
    81  			name: "success when first request",
    82  			setup: func() {
    83  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
    84  					return []contractshttp.Limit{
    85  						limit.PerMinute(1),
    86  					}
    87  				}).Once()
    88  				mockConfig.On("GetString", "cache.prefix").Return("goravel").Once()
    89  				mockCache.On("Has", mock.MatchedBy(func(timer string) bool {
    90  					return strings.HasSuffix(timer, ":timer")
    91  				})).Return(false).Once()
    92  				mockCache.On("Put", mock.MatchedBy(func(timer string) bool {
    93  					return strings.HasSuffix(timer, ":timer")
    94  				}), now.Timestamp(), time.Duration(1)*time.Minute).Return(nil).Once()
    95  				mockCache.On("Put", mock.MatchedBy(func(key string) bool {
    96  					return strings.HasPrefix(key, "goravel:throttle:test:")
    97  				}), 1, time.Duration(1)*time.Minute).Return(nil).Once()
    98  
    99  				assert.NotPanics(t, func() {
   100  					Throttle("test")(ctx)
   101  				})
   102  			},
   103  			assert: func() {
   104  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Reset"])
   105  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["Retry-After"])
   106  				assert.Equal(t, "1", ctx.Response().(*TestResponse).Headers["X-RateLimit-Limit"])
   107  				assert.Equal(t, "0", ctx.Response().(*TestResponse).Headers["X-RateLimit-Remaining"])
   108  			},
   109  		},
   110  		{
   111  			name: "error when put timer fail in first request",
   112  			setup: func() {
   113  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
   114  					return []contractshttp.Limit{
   115  						limit.PerMinute(1),
   116  					}
   117  				}).Once()
   118  				mockConfig.On("GetString", "cache.prefix").Return("goravel").Once()
   119  				mockCache.On("Has", mock.MatchedBy(func(timer string) bool {
   120  					return strings.HasSuffix(timer, ":timer")
   121  				})).Return(false).Once()
   122  				mockCache.On("Put", mock.MatchedBy(func(timer string) bool {
   123  					return strings.HasSuffix(timer, ":timer")
   124  				}), now.Timestamp(), time.Duration(1)*time.Minute).Return(errors.New("error")).Once()
   125  
   126  				assert.Panics(t, func() {
   127  					Throttle("test")(ctx)
   128  				})
   129  			},
   130  			assert: func() {},
   131  		},
   132  		{
   133  			name: "error when put key fail in first request",
   134  			setup: func() {
   135  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
   136  					return []contractshttp.Limit{
   137  						limit.PerMinute(1),
   138  					}
   139  				}).Once()
   140  				mockConfig.On("GetString", "cache.prefix").Return("goravel").Once()
   141  				mockCache.On("Has", mock.MatchedBy(func(timer string) bool {
   142  					return strings.HasSuffix(timer, ":timer")
   143  				})).Return(false).Once()
   144  				mockCache.On("Put", mock.MatchedBy(func(timer string) bool {
   145  					return strings.HasSuffix(timer, ":timer")
   146  				}), now.Timestamp(), time.Duration(1)*time.Minute).Return(nil).Once()
   147  				mockCache.On("Put", mock.MatchedBy(func(key string) bool {
   148  					return strings.HasPrefix(key, "goravel:throttle:test:")
   149  				}), 1, time.Duration(1)*time.Minute).Return(errors.New("error")).Once()
   150  
   151  				assert.Panics(t, func() {
   152  					Throttle("test")(ctx)
   153  				})
   154  			},
   155  			assert: func() {},
   156  		},
   157  		{
   158  			name: "success when not over MaxAttempts",
   159  			setup: func() {
   160  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
   161  					return []contractshttp.Limit{
   162  						limit.PerMinute(2),
   163  					}
   164  				}).Once()
   165  				mockConfig.On("GetString", "cache.prefix").Return("goravel").Once()
   166  				mockCache.On("Has", mock.MatchedBy(func(timer string) bool {
   167  					return strings.HasSuffix(timer, ":timer")
   168  				})).Return(true).Once()
   169  				mockCache.On("GetInt", mock.MatchedBy(func(key string) bool {
   170  					return strings.HasPrefix(key, "goravel:throttle:test:")
   171  				}), 0).Return(1).Once()
   172  				mockCache.On("Increment", mock.MatchedBy(func(key string) bool {
   173  					return strings.HasPrefix(key, "goravel:throttle:test:")
   174  				})).Return(2, nil).Once()
   175  
   176  				assert.NotPanics(t, func() {
   177  					Throttle("test")(ctx)
   178  				})
   179  			},
   180  			assert: func() {
   181  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Reset"])
   182  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["Retry-After"])
   183  				assert.Equal(t, "2", ctx.Response().(*TestResponse).Headers["X-RateLimit-Limit"])
   184  				assert.Equal(t, "0", ctx.Response().(*TestResponse).Headers["X-RateLimit-Remaining"])
   185  			},
   186  		},
   187  		{
   188  			name: "success when over MaxAttempts",
   189  			setup: func() {
   190  				mockRateLimiterFacade.On("Limiter", "test").Return(func(ctx contractshttp.Context) []contractshttp.Limit {
   191  					return []contractshttp.Limit{
   192  						limit.PerMinute(2),
   193  					}
   194  				}).Once()
   195  				mockConfig.On("GetString", "cache.prefix").Return("goravel").Once()
   196  				mockCache.On("Has", mock.MatchedBy(func(timer string) bool {
   197  					return strings.HasSuffix(timer, ":timer")
   198  				})).Return(true).Once()
   199  				mockCache.On("GetInt", mock.MatchedBy(func(key string) bool {
   200  					return strings.HasPrefix(key, "goravel:throttle:test:")
   201  				}), 0).Return(2).Once()
   202  				mockCache.On("GetInt", mock.MatchedBy(func(timer string) bool {
   203  					return strings.HasSuffix(timer, ":timer")
   204  				}), 0).Return(int(now.Timestamp())).Once()
   205  
   206  				assert.NotPanics(t, func() {
   207  					Throttle("test")(ctx)
   208  				})
   209  			},
   210  			assert: func() {
   211  				assert.Equal(t, strconv.FormatInt(now.Timestamp()+60, 10), ctx.Response().(*TestResponse).Headers["X-RateLimit-Reset"])
   212  				assert.Equal(t, "60", ctx.Response().(*TestResponse).Headers["Retry-After"])
   213  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Limit"])
   214  				assert.Empty(t, ctx.Response().(*TestResponse).Headers["X-RateLimit-Remaining"])
   215  			},
   216  		},
   217  	}
   218  
   219  	for _, test := range tests {
   220  		t.Run(test.name, func(t *testing.T) {
   221  			ctx = new(TestContext)
   222  			mockCache = cachemocks.NewCache(t)
   223  			mockConfig = configmocks.NewConfig(t)
   224  			mockRateLimiterFacade = httpmocks.NewRateLimiter(t)
   225  			http.CacheFacade = mockCache
   226  			http.ConfigFacade = mockConfig
   227  			http.RateLimiterFacade = mockRateLimiterFacade
   228  			test.setup()
   229  			test.assert()
   230  		})
   231  	}
   232  }
   233  
   234  type TestContext struct {
   235  	response contractshttp.ContextResponse
   236  }
   237  
   238  func (r *TestContext) Deadline() (deadline time.Time, ok bool) {
   239  
   240  	panic("do not need to implement it")
   241  }
   242  
   243  func (r *TestContext) Done() <-chan struct{} {
   244  
   245  	panic("do not need to implement it")
   246  }
   247  
   248  func (r *TestContext) Err() error {
   249  
   250  	panic("do not need to implement it")
   251  }
   252  
   253  func (r *TestContext) Value(key any) any {
   254  
   255  	panic("do not need to implement it")
   256  }
   257  
   258  func (r *TestContext) Context() context.Context {
   259  
   260  	panic("do not need to implement it")
   261  }
   262  
   263  func (r *TestContext) WithValue(key string, value any) {
   264  
   265  	panic("do not need to implement it")
   266  }
   267  
   268  func (r *TestContext) Request() contractshttp.ContextRequest {
   269  	return new(TestRequest)
   270  }
   271  
   272  func (r *TestContext) Response() contractshttp.ContextResponse {
   273  	if r.response == nil {
   274  		r.response = &TestResponse{
   275  			Headers: make(map[string]string),
   276  		}
   277  	}
   278  
   279  	return r.response
   280  }
   281  
   282  type TestRequest struct{}
   283  
   284  func (r *TestRequest) Header(key string, defaultValue ...string) string {
   285  
   286  	panic("do not need to implement it")
   287  }
   288  
   289  func (r *TestRequest) Headers() nethttp.Header {
   290  
   291  	panic("do not need to implement it")
   292  }
   293  
   294  func (r *TestRequest) Method() string {
   295  
   296  	panic("do not need to implement it")
   297  }
   298  
   299  func (r *TestRequest) Path() string {
   300  	return "/test"
   301  }
   302  
   303  func (r *TestRequest) Url() string {
   304  
   305  	panic("do not need to implement it")
   306  }
   307  
   308  func (r *TestRequest) FullUrl() string {
   309  
   310  	panic("do not need to implement it")
   311  }
   312  
   313  func (r *TestRequest) Ip() string {
   314  	return "127.0.0.1"
   315  }
   316  
   317  func (r *TestRequest) Host() string {
   318  
   319  	panic("do not need to implement it")
   320  }
   321  
   322  func (r *TestRequest) All() map[string]any {
   323  
   324  	panic("do not need to implement it")
   325  }
   326  
   327  func (r *TestRequest) Bind(obj any) error {
   328  	panic("do not need to implement it")
   329  }
   330  
   331  func (r *TestRequest) Route(key string) string {
   332  
   333  	panic("do not need to implement it")
   334  }
   335  
   336  func (r *TestRequest) RouteInt(key string) int {
   337  	panic("do not need to implement it")
   338  }
   339  
   340  func (r *TestRequest) RouteInt64(key string) int64 {
   341  
   342  	panic("do not need to implement it")
   343  }
   344  
   345  func (r *TestRequest) Query(key string, defaultValue ...string) string {
   346  
   347  	panic("do not need to implement it")
   348  }
   349  
   350  func (r *TestRequest) QueryInt(key string, defaultValue ...int) int {
   351  
   352  	panic("do not need to implement it")
   353  }
   354  
   355  func (r *TestRequest) QueryInt64(key string, defaultValue ...int64) int64 {
   356  
   357  	panic("do not need to implement it")
   358  }
   359  
   360  func (r *TestRequest) QueryBool(key string, defaultValue ...bool) bool {
   361  
   362  	panic("do not need to implement it")
   363  }
   364  
   365  func (r *TestRequest) QueryArray(key string) []string {
   366  
   367  	panic("do not need to implement it")
   368  }
   369  
   370  func (r *TestRequest) QueryMap(key string) map[string]string {
   371  
   372  	panic("do not need to implement it")
   373  }
   374  
   375  func (r *TestRequest) Queries() map[string]string {
   376  
   377  	panic("do not need to implement it")
   378  }
   379  
   380  func (r *TestRequest) Form(key string, defaultValue ...string) string {
   381  
   382  	panic("do not need to implement it")
   383  }
   384  
   385  func (r *TestRequest) Json(key string, defaultValue ...string) string {
   386  
   387  	panic("do not need to implement it")
   388  }
   389  
   390  func (r *TestRequest) Input(key string, defaultValue ...string) string {
   391  
   392  	panic("do not need to implement it")
   393  }
   394  
   395  func (r *TestRequest) InputArray(key string, defaultValue ...[]string) []string {
   396  
   397  	panic("do not need to implement it")
   398  }
   399  
   400  func (r *TestRequest) InputMap(key string, defaultValue ...map[string]string) map[string]string {
   401  
   402  	panic("do not need to implement it")
   403  }
   404  
   405  func (r *TestRequest) InputInt(key string, defaultValue ...int) int {
   406  
   407  	panic("do not need to implement it")
   408  }
   409  
   410  func (r *TestRequest) InputInt64(key string, defaultValue ...int64) int64 {
   411  
   412  	panic("do not need to implement it")
   413  }
   414  
   415  func (r *TestRequest) InputBool(key string, defaultValue ...bool) bool {
   416  
   417  	panic("do not need to implement it")
   418  }
   419  
   420  func (r *TestRequest) File(name string) (filesystem.File, error) {
   421  
   422  	panic("do not need to implement it")
   423  }
   424  
   425  func (r *TestRequest) AbortWithStatus(code int) {}
   426  
   427  func (r *TestRequest) AbortWithStatusJson(code int, jsonObj any) {
   428  
   429  	panic("do not need to implement it")
   430  }
   431  
   432  func (r *TestRequest) Next() {}
   433  
   434  func (r *TestRequest) Origin() *nethttp.Request {
   435  
   436  	panic("do not need to implement it")
   437  }
   438  
   439  func (r *TestRequest) Validate(rules map[string]string, options ...validation.Option) (validation.Validator, error) {
   440  
   441  	panic("do not need to implement it")
   442  }
   443  
   444  func (r *TestRequest) ValidateRequest(request contractshttp.FormRequest) (validation.Errors, error) {
   445  
   446  	panic("do not need to implement it")
   447  }
   448  
   449  type TestResponse struct {
   450  	Headers map[string]string
   451  }
   452  
   453  func (r *TestResponse) Data(code int, contentType string, data []byte) contractshttp.Response {
   454  	panic("do not need to implement it")
   455  }
   456  
   457  func (r *TestResponse) Download(filepath, filename string) contractshttp.Response {
   458  	panic("do not need to implement it")
   459  }
   460  
   461  func (r *TestResponse) File(filepath string) contractshttp.Response {
   462  	panic("do not need to implement it")
   463  }
   464  
   465  func (r *TestResponse) Header(key, value string) contractshttp.ContextResponse {
   466  	r.Headers[key] = value
   467  
   468  	return r
   469  }
   470  
   471  func (r *TestResponse) Json(code int, obj any) contractshttp.Response {
   472  	panic("do not need to implement it")
   473  }
   474  
   475  func (r *TestResponse) Origin() contractshttp.ResponseOrigin {
   476  	panic("do not need to implement it")
   477  }
   478  
   479  func (r *TestResponse) Redirect(code int, location string) contractshttp.Response {
   480  	panic("do not need to implement it")
   481  }
   482  
   483  func (r *TestResponse) String(code int, format string, values ...any) contractshttp.Response {
   484  	panic("do not need to implement it")
   485  }
   486  
   487  func (r *TestResponse) Success() contractshttp.ResponseSuccess {
   488  	panic("do not need to implement it")
   489  }
   490  
   491  func (r *TestResponse) Status(code int) contractshttp.ResponseStatus {
   492  	panic("do not need to implement it")
   493  }
   494  
   495  func (r *TestResponse) Writer() nethttp.ResponseWriter {
   496  	panic("do not need to implement it")
   497  }
   498  
   499  func (r *TestResponse) Flush() {
   500  	panic("do not need to implement it")
   501  }
   502  
   503  func (r *TestResponse) View() contractshttp.ResponseView {
   504  	panic("do not need to implement it")
   505  }
   506  
   507  type TestLimit struct{}
   508  
   509  func (r *TestLimit) By(key string) contractshttp.Limit {
   510  	panic("do not need to implement it")
   511  }
   512  
   513  func (r *TestLimit) Response(f func(ctx contractshttp.Context)) contractshttp.Limit {
   514  	panic("do not need to implement it")
   515  }