github.com/mattermosttest/mattermost-server/v5@v5.0.0-20200917143240-9dfa12e121f9/app/ratelimit.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"math"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/mattermost/mattermost-server/v5/mlog"
    13  	"github.com/mattermost/mattermost-server/v5/model"
    14  	"github.com/mattermost/mattermost-server/v5/utils"
    15  	"github.com/pkg/errors"
    16  	"github.com/throttled/throttled"
    17  	"github.com/throttled/throttled/store/memstore"
    18  )
    19  
    20  type RateLimiter struct {
    21  	throttledRateLimiter *throttled.GCRARateLimiter
    22  	useAuth              bool
    23  	useIP                bool
    24  	header               string
    25  	trustedProxyIPHeader []string
    26  }
    27  
    28  func NewRateLimiter(settings *model.RateLimitSettings, trustedProxyIPHeader []string) (*RateLimiter, error) {
    29  	store, err := memstore.New(*settings.MemoryStoreSize)
    30  	if err != nil {
    31  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
    32  	}
    33  
    34  	quota := throttled.RateQuota{
    35  		MaxRate:  throttled.PerSec(*settings.PerSec),
    36  		MaxBurst: *settings.MaxBurst,
    37  	}
    38  
    39  	throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
    40  	if err != nil {
    41  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
    42  	}
    43  
    44  	return &RateLimiter{
    45  		throttledRateLimiter: throttledRateLimiter,
    46  		useAuth:              *settings.VaryByUser,
    47  		useIP:                *settings.VaryByRemoteAddr,
    48  		header:               settings.VaryByHeader,
    49  		trustedProxyIPHeader: trustedProxyIPHeader,
    50  	}, nil
    51  }
    52  
    53  func (rl *RateLimiter) GenerateKey(r *http.Request) string {
    54  	key := ""
    55  
    56  	if rl.useAuth {
    57  		token, tokenLocation := ParseAuthTokenFromRequest(r)
    58  		if tokenLocation != TokenLocationNotFound {
    59  			key += token
    60  		} else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
    61  			key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
    62  		}
    63  	} else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
    64  		key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
    65  	}
    66  
    67  	// Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
    68  	// most common headers anyway.
    69  	if rl.header != "" {
    70  		key += strings.ToLower(r.Header.Get(rl.header))
    71  	}
    72  
    73  	return key
    74  }
    75  
    76  func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
    77  	limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
    78  	if err != nil {
    79  		mlog.Critical("Internal server error when rate limiting. Rate Limiting broken.", mlog.Err(err))
    80  		return false
    81  	}
    82  
    83  	setRateLimitHeaders(w, context)
    84  
    85  	if limited {
    86  		mlog.Error("Denied due to throttling settings code=429", mlog.String("key", key))
    87  		http.Error(w, "limit exceeded", 429)
    88  	}
    89  
    90  	return limited
    91  }
    92  
    93  func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
    94  	if rl.useAuth {
    95  		if rl.RateLimitWriter(userId, w) {
    96  			return true
    97  		}
    98  	}
    99  	return false
   100  }
   101  
   102  func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
   103  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   104  		key := rl.GenerateKey(r)
   105  		limited := rl.RateLimitWriter(key, w)
   106  
   107  		if !limited {
   108  			wrappedHandler.ServeHTTP(w, r)
   109  		}
   110  	})
   111  }
   112  
   113  // Copied from https://github.com/throttled/throttled http.go
   114  func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
   115  	if v := context.Limit; v >= 0 {
   116  		w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
   117  	}
   118  
   119  	if v := context.Remaining; v >= 0 {
   120  		w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
   121  	}
   122  
   123  	if v := context.ResetAfter; v >= 0 {
   124  		vi := int(math.Ceil(v.Seconds()))
   125  		w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
   126  	}
   127  
   128  	if v := context.RetryAfter; v >= 0 {
   129  		vi := int(math.Ceil(v.Seconds()))
   130  		w.Header().Add("Retry-After", strconv.Itoa(vi))
   131  	}
   132  }