github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/pkg/limits/rate_limiting_test.go (about)

     1  package limits
     2  
     import (
	"strconv"
	"testing"
	"time"

	"github.com/cozy/cozy-stack/pkg/prefixer"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/require"
)
    10  
    11  func TestRate(t *testing.T) {
    12  	var testInstance = prefixer.NewPrefixer(0, "cozy.example.net", "cozy-example-limits")
    13  
    14  	rOpt, err := redis.ParseURL("redis://localhost:6379/0")
    15  	require.NoError(t, err)
    16  
    17  	redisClient := redis.NewClient(rOpt)
    18  
    19  	tests := []struct {
    20  		Name      string
    21  		Client    Counter
    22  		NeedRedis bool
    23  	}{
    24  		{
    25  			Name:      "InMemory",
    26  			Client:    NewInMemory(),
    27  			NeedRedis: false,
    28  		},
    29  		{
    30  			Name:      "Redis",
    31  			Client:    NewRedis(redisClient),
    32  			NeedRedis: true,
    33  		},
    34  	}
    35  
    36  	for _, test := range tests {
    37  		t.Run(test.Name, func(t *testing.T) {
    38  			if test.NeedRedis && testing.Short() {
    39  				t.Skip("a redis is required for this test: test skipped due to the use of --short flag")
    40  			}
    41  
    42  			limiter := &RateLimiter{counter: test.Client}
    43  
    44  			t.Run("LoginRateNotExceeded", func(t *testing.T) {
    45  				require.NoError(t, limiter.CheckRateLimit(testInstance, AuthType))
    46  			})
    47  
    48  			t.Run("LoginRateExceeded", func(t *testing.T) {
    49  				// Take into account the call above
    50  				for i := 1; i < 1000; i++ {
    51  					require.NoError(t, limiter.CheckRateLimit(testInstance, AuthType))
    52  				}
    53  				err := limiter.CheckRateLimit(testInstance, AuthType)
    54  				require.Error(t, err)
    55  			})
    56  
    57  			t.Run("2FAGenerationNotExceeded", func(t *testing.T) {
    58  				require.NoError(t, limiter.CheckRateLimit(testInstance, TwoFactorGenerationType))
    59  			})
    60  
    61  			t.Run("2FAGenerationExceeded", func(t *testing.T) {
    62  				// Take into account the call above
    63  				for i := 1; i < 20; i++ {
    64  					require.NoError(t, limiter.CheckRateLimit(testInstance, TwoFactorGenerationType))
    65  				}
    66  
    67  				err := limiter.CheckRateLimit(testInstance, TwoFactorGenerationType)
    68  				require.Error(t, err)
    69  			})
    70  
    71  			t.Run("2FARateExceededNotExceeded", func(t *testing.T) {
    72  				require.NoError(t, limiter.CheckRateLimit(testInstance, TwoFactorType))
    73  			})
    74  
    75  			t.Run("2FARateExceeded", func(t *testing.T) {
    76  				// Take into account the call above
    77  				for i := 1; i < 10; i++ {
    78  					require.NoError(t, limiter.CheckRateLimit(testInstance, TwoFactorType))
    79  				}
    80  
    81  				err := limiter.CheckRateLimit(testInstance, TwoFactorType)
    82  				require.Error(t, err)
    83  			})
    84  		})
    85  	}
    86  }