github.com/dschalla/mattermost-server@v4.8.1-rc1+incompatible/app/ratelimit.go (about)

     1  // Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
     2  // See License.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"math"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  
    12  	l4g "github.com/alecthomas/log4go"
    13  	"github.com/mattermost/mattermost-server/model"
    14  	"github.com/mattermost/mattermost-server/utils"
    15  	"github.com/pkg/errors"
    16  	throttled "gopkg.in/throttled/throttled.v2"
    17  	"gopkg.in/throttled/throttled.v2/store/memstore"
    18  )
    19  
    20  type RateLimiter struct {
    21  	throttledRateLimiter *throttled.GCRARateLimiter
    22  	useAuth              bool
    23  	useIP                bool
    24  	header               string
    25  }
    26  
    27  func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) {
    28  	store, err := memstore.New(*settings.MemoryStoreSize)
    29  	if err != nil {
    30  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
    31  	}
    32  
    33  	quota := throttled.RateQuota{
    34  		MaxRate:  throttled.PerSec(*settings.PerSec),
    35  		MaxBurst: *settings.MaxBurst,
    36  	}
    37  
    38  	throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
    39  	if err != nil {
    40  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
    41  	}
    42  
    43  	return &RateLimiter{
    44  		throttledRateLimiter: throttledRateLimiter,
    45  		useAuth:              *settings.VaryByUser,
    46  		useIP:                *settings.VaryByRemoteAddr,
    47  		header:               settings.VaryByHeader,
    48  	}, nil
    49  }
    50  
    51  func (rl *RateLimiter) GenerateKey(r *http.Request) string {
    52  	key := ""
    53  
    54  	if rl.useAuth {
    55  		token, tokenLocation := ParseAuthTokenFromRequest(r)
    56  		if tokenLocation != TokenLocationNotFound {
    57  			key += token
    58  		} else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
    59  			key += utils.GetIpAddress(r)
    60  		}
    61  	} else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
    62  		key += utils.GetIpAddress(r)
    63  	}
    64  
    65  	// Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
    66  	// most common headers anyway.
    67  	if rl.header != "" {
    68  		key += strings.ToLower(r.Header.Get(rl.header))
    69  	}
    70  
    71  	return key
    72  }
    73  
    74  func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
    75  	limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
    76  	if err != nil {
    77  		l4g.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
    78  		return false
    79  	}
    80  
    81  	setRateLimitHeaders(w, context)
    82  
    83  	if limited {
    84  		l4g.Error("Denied due to throttling settings code=429 key=%v", key)
    85  		http.Error(w, "limit exceeded", 429)
    86  	}
    87  
    88  	return limited
    89  }
    90  
    91  func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
    92  	if rl.useAuth {
    93  		if rl.RateLimitWriter(userId, w) {
    94  			return true
    95  		}
    96  	}
    97  	return false
    98  }
    99  
   100  func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
   101  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   102  		key := rl.GenerateKey(r)
   103  		limited := rl.RateLimitWriter(key, w)
   104  
   105  		if !limited {
   106  			wrappedHandler.ServeHTTP(w, r)
   107  		}
   108  	})
   109  }
   110  
   111  // Copied from https://github.com/throttled/throttled http.go
   112  func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
   113  	if v := context.Limit; v >= 0 {
   114  		w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
   115  	}
   116  
   117  	if v := context.Remaining; v >= 0 {
   118  		w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
   119  	}
   120  
   121  	if v := context.ResetAfter; v >= 0 {
   122  		vi := int(math.Ceil(v.Seconds()))
   123  		w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
   124  	}
   125  
   126  	if v := context.RetryAfter; v >= 0 {
   127  		vi := int(math.Ceil(v.Seconds()))
   128  		w.Header().Add("Retry-After", strconv.Itoa(vi))
   129  	}
   130  }