github.com/xzl8028/xenia-server@v0.0.0-20190809101854-18450a97da63/app/ratelimit_test.go (about)

     1  // Copyright (c) 2018-present Xenia, Inc. All Rights Reserved.
     2  // See License.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/xzl8028/xenia-server/model"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func genRateLimitSettings(useAuth, useIP bool, header string) *model.RateLimitSettings {
    17  	return &model.RateLimitSettings{
    18  		Enable:           model.NewBool(true),
    19  		PerSec:           model.NewInt(10),
    20  		MaxBurst:         model.NewInt(100),
    21  		MemoryStoreSize:  model.NewInt(10000),
    22  		VaryByRemoteAddr: model.NewBool(useIP),
    23  		VaryByUser:       model.NewBool(useAuth),
    24  		VaryByHeader:     header,
    25  	}
    26  }
    27  
    // TestNewRateLimiterSuccess checks that NewRateLimiter accepts valid settings,
// both with no trusted proxy headers and with X-Forwarded-For trusted.
func TestNewRateLimiterSuccess(t *testing.T) {
	settings := genRateLimitSettings(false, false, "")

	// Check the error before the value so that a failure reports the
	// underlying cause rather than only "expected value not to be nil".
	rateLimiter, err := NewRateLimiter(settings, nil)
	require.NoError(t, err)
	require.NotNil(t, rateLimiter)

	rateLimiter, err = NewRateLimiter(settings, []string{"X-Forwarded-For"})
	require.NoError(t, err)
	require.NotNil(t, rateLimiter)
}
    38  
    // TestNewRateLimiterFailure checks that NewRateLimiter rejects settings whose
// MaxBurst is negative. It must return a nil limiter and a non-nil error,
// both with no trusted proxy headers and with X-Forwarded-For and X-Real-Ip
// trusted.
func TestNewRateLimiterFailure(t *testing.T) {
	badSettings := genRateLimitSettings(false, false, "")
	badSettings.MaxBurst = model.NewInt(-100)

	trustedHeaderSets := [][]string{
		nil,
		{"X-Forwarded-For", "X-Real-Ip"},
	}
	for _, trusted := range trustedHeaderSets {
		limiter, err := NewRateLimiter(badSettings, trusted)
		require.Nil(t, limiter)
		require.Error(t, err)
	}
}
    50  
    // TestGenerateKey checks that RateLimiter.GenerateKey builds its key from the
// request according to the VaryByUser, VaryByRemoteAddr and VaryByHeader
// settings. Each case seeds the request with a session-token cookie
// (authTokenResult), a remote address (ipResult) and a header value
// (headerResult). expectedKey records which of those parts the limiter should
// use; parts marked "notme" must not appear in the key.
func TestGenerateKey(t *testing.T) {
	cases := []struct {
		useAuth         bool
		useIP           bool
		header          string
		authTokenResult string
		ipResult        string
		headerResult    string
		expectedKey     string
	}{
		{false, false, "", "", "", "", ""},
		{true, false, "", "resultkey", "notme", "notme", "resultkey"},
		{false, true, "", "notme", "resultkey", "notme", "resultkey"},
		{false, false, "myheader", "notme", "notme", "resultkey", "resultkey"},
		{true, true, "", "resultkey", "ipaddr", "notme", "resultkey"},
		{true, true, "", "", "ipaddr", "notme", "ipaddr"},
		{true, true, "myheader", "resultkey", "ipaddr", "hadd", "resultkeyhadd"},
		{true, true, "myheader", "", "ipaddr", "hadd", "ipaddrhadd"},
	}

	for testnum, tc := range cases {
		// Subtests run synchronously (no t.Parallel), so capturing tc is safe.
		t.Run(strconv.Itoa(testnum), func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.authTokenResult != "" {
				req.AddCookie(&http.Cookie{
					Name:  model.SESSION_COOKIE_TOKEN,
					Value: tc.authTokenResult,
				})
			}
			req.RemoteAddr = tc.ipResult + ":80"
			// The value is set even when tc.header is empty (stored under the
			// empty key). Those cases would therefore fail if the limiter read
			// a header while VaryByHeader is "".
			if tc.headerResult != "" {
				req.Header.Set(tc.header, tc.headerResult)
			}

			// Fail cleanly on a constructor error instead of panicking on a
			// nil limiter below.
			rateLimiter, err := NewRateLimiter(genRateLimitSettings(tc.useAuth, tc.useIP, tc.header), nil)
			require.NoError(t, err)
			require.NotNil(t, rateLimiter)

			key := rateLimiter.GenerateKey(req)

			require.Equal(t, tc.expectedKey, key, "Wrong key on test "+strconv.Itoa(testnum))
		})
	}
}
    91  
    // TestGenerateKey_TrustedHeader checks which client address GenerateKey uses
// when the request carries an X-Forwarded-For header. If X-Forwarded-For is
// passed as a trusted header, the expected key is the first address listed in
// it. Otherwise the expected key is RemoteAddr without its port. The request
// has no session cookie, so VaryByUser contributes nothing to the key.
func TestGenerateKey_TrustedHeader(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "10.10.10.5:80"
	req.Header.Set("X-Forwarded-For", "10.6.3.1, 10.5.1.2")

	rateLimiter, err := NewRateLimiter(genRateLimitSettings(true, true, ""), []string{"X-Forwarded-For"})
	require.NoError(t, err)
	require.NotNil(t, rateLimiter)
	key := rateLimiter.GenerateKey(req)
	require.Equal(t, "10.6.3.1", key, "Wrong key on test with allowed trusted proxy header")

	rateLimiter, err = NewRateLimiter(genRateLimitSettings(true, true, ""), nil)
	require.NoError(t, err)
	require.NotNil(t, rateLimiter)
	key = rateLimiter.GenerateKey(req)
	require.Equal(t, "10.10.10.5", key, "Wrong key on test without allowed trusted proxy header")
}