github.com/jlevesy/mattermost-server@v5.3.2-0.20181003190404-7468f35cb0c8+incompatible/app/ratelimit.go (about)

     1  // Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
     2  // See License.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"net/http"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/mattermost/mattermost-server/mlog"
    14  	"github.com/mattermost/mattermost-server/model"
    15  	"github.com/mattermost/mattermost-server/utils"
    16  	"github.com/pkg/errors"
    17  	"github.com/throttled/throttled"
    18  	"github.com/throttled/throttled/store/memstore"
    19  )
    20  
// RateLimiter throttles HTTP requests with a GCRA limiter from the throttled
// package. The boolean/header fields select which request attributes are
// combined into the bucket key (see GenerateKey).
type RateLimiter struct {
	throttledRateLimiter *throttled.GCRARateLimiter // limiter + in-memory store built by NewRateLimiter
	useAuth              bool                       // key by auth token when present (from VaryByUser)
	useIP                bool                       // key by client IP (from VaryByRemoteAddr)
	header               string                     // request header appended to the key (from VaryByHeader); "" disables
}
    27  
    28  func NewRateLimiter(settings *model.RateLimitSettings) (*RateLimiter, error) {
    29  	store, err := memstore.New(*settings.MemoryStoreSize)
    30  	if err != nil {
    31  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
    32  	}
    33  
    34  	quota := throttled.RateQuota{
    35  		MaxRate:  throttled.PerSec(*settings.PerSec),
    36  		MaxBurst: *settings.MaxBurst,
    37  	}
    38  
    39  	throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
    40  	if err != nil {
    41  		return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
    42  	}
    43  
    44  	return &RateLimiter{
    45  		throttledRateLimiter: throttledRateLimiter,
    46  		useAuth:              *settings.VaryByUser,
    47  		useIP:                *settings.VaryByRemoteAddr,
    48  		header:               settings.VaryByHeader,
    49  	}, nil
    50  }
    51  
    52  func (rl *RateLimiter) GenerateKey(r *http.Request) string {
    53  	key := ""
    54  
    55  	if rl.useAuth {
    56  		token, tokenLocation := ParseAuthTokenFromRequest(r)
    57  		if tokenLocation != TokenLocationNotFound {
    58  			key += token
    59  		} else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
    60  			key += utils.GetIpAddress(r)
    61  		}
    62  	} else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
    63  		key += utils.GetIpAddress(r)
    64  	}
    65  
    66  	// Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
    67  	// most common headers anyway.
    68  	if rl.header != "" {
    69  		key += strings.ToLower(r.Header.Get(rl.header))
    70  	}
    71  
    72  	return key
    73  }
    74  
    75  func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
    76  	limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
    77  	if err != nil {
    78  		mlog.Critical("Internal server error when rate limiting. Rate Limiting broken. Error:" + err.Error())
    79  		return false
    80  	}
    81  
    82  	setRateLimitHeaders(w, context)
    83  
    84  	if limited {
    85  		mlog.Error(fmt.Sprintf("Denied due to throttling settings code=429 key=%v", key))
    86  		http.Error(w, "limit exceeded", 429)
    87  	}
    88  
    89  	return limited
    90  }
    91  
    92  func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
    93  	if rl.useAuth {
    94  		if rl.RateLimitWriter(userId, w) {
    95  			return true
    96  		}
    97  	}
    98  	return false
    99  }
   100  
   101  func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
   102  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   103  		key := rl.GenerateKey(r)
   104  		limited := rl.RateLimitWriter(key, w)
   105  
   106  		if !limited {
   107  			wrappedHandler.ServeHTTP(w, r)
   108  		}
   109  	})
   110  }
   111  
   112  // Copied from https://github.com/throttled/throttled http.go
   113  func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
   114  	if v := context.Limit; v >= 0 {
   115  		w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
   116  	}
   117  
   118  	if v := context.Remaining; v >= 0 {
   119  		w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
   120  	}
   121  
   122  	if v := context.ResetAfter; v >= 0 {
   123  		vi := int(math.Ceil(v.Seconds()))
   124  		w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
   125  	}
   126  
   127  	if v := context.RetryAfter; v >= 0 {
   128  		vi := int(math.Ceil(v.Seconds()))
   129  		w.Header().Add("Retry-After", strconv.Itoa(vi))
   130  	}
   131  }