github.com/xzl8028/xenia-server@v0.0.0-20190809101854-18450a97da63/app/ratelimit.go (about)

     1  // Copyright (c) 2018-present Xenia, Inc. All Rights Reserved.
     2  // See License.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"net/http"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/xzl8028/xenia-server/mlog"
    14  	"github.com/xzl8028/xenia-server/model"
    15  	"github.com/xzl8028/xenia-server/utils"
    16  	"github.com/pkg/errors"
    17  	"github.com/throttled/throttled"
    18  	"github.com/throttled/throttled/store/memstore"
    19  )
    20  
    21  type RateLimiter struct {
    22  	throttledRateLimiter *throttled.GCRARateLimiter
    23  	useAuth              bool
    24  	useIP                bool
    25  	header               string
    26  	trustedProxyIPHeader []string
    27  }
    28  
    29  func NewRateLimiter(settings *model.RateLimitSettings, trustedProxyIPHeader []string) (*RateLimiter, error) {
    30  	store, err := memstore.New(*settings.MemoryStoreSize)
    31  	if err != nil {
    32  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
    33  	}
    34  
    35  	quota := throttled.RateQuota{
    36  		MaxRate:  throttled.PerSec(*settings.PerSec),
    37  		MaxBurst: *settings.MaxBurst,
    38  	}
    39  
    40  	throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
    41  	if err != nil {
    42  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
    43  	}
    44  
    45  	return &RateLimiter{
    46  		throttledRateLimiter: throttledRateLimiter,
    47  		useAuth:              *settings.VaryByUser,
    48  		useIP:                *settings.VaryByRemoteAddr,
    49  		header:               settings.VaryByHeader,
    50  		trustedProxyIPHeader: trustedProxyIPHeader,
    51  	}, nil
    52  }
    53  
    54  func (rl *RateLimiter) GenerateKey(r *http.Request) string {
    55  	key := ""
    56  
    57  	if rl.useAuth {
    58  		token, tokenLocation := ParseAuthTokenFromRequest(r)
    59  		if tokenLocation != TokenLocationNotFound {
    60  			key += token
    61  		} else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
    62  			key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
    63  		}
    64  	} else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
    65  		key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
    66  	}
    67  
    68  	// Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
    69  	// most common headers anyway.
    70  	if rl.header != "" {
    71  		key += strings.ToLower(r.Header.Get(rl.header))
    72  	}
    73  
    74  	return key
    75  }
    76  
    77  func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
    78  	limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
    79  	if err != nil {
    80  		mlog.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
    81  		return false
    82  	}
    83  
    84  	setRateLimitHeaders(w, context)
    85  
    86  	if limited {
    87  		mlog.Error(fmt.Sprintf("Denied due to throttling settings code=429 key=%v", key))
    88  		http.Error(w, "limit exceeded", 429)
    89  	}
    90  
    91  	return limited
    92  }
    93  
    94  func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
    95  	if rl.useAuth {
    96  		if rl.RateLimitWriter(userId, w) {
    97  			return true
    98  		}
    99  	}
   100  	return false
   101  }
   102  
   103  func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
   104  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   105  		key := rl.GenerateKey(r)
   106  		limited := rl.RateLimitWriter(key, w)
   107  
   108  		if !limited {
   109  			wrappedHandler.ServeHTTP(w, r)
   110  		}
   111  	})
   112  }
   113  
   114  // Copied from https://github.com/throttled/throttled http.go
   115  func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
   116  	if v := context.Limit; v >= 0 {
   117  		w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
   118  	}
   119  
   120  	if v := context.Remaining; v >= 0 {
   121  		w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
   122  	}
   123  
   124  	if v := context.ResetAfter; v >= 0 {
   125  		vi := int(math.Ceil(v.Seconds()))
   126  		w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
   127  	}
   128  
   129  	if v := context.RetryAfter; v >= 0 {
   130  		vi := int(math.Ceil(v.Seconds()))
   131  		w.Header().Add("Retry-After", strconv.Itoa(vi))
   132  	}
   133  }