github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/ratelimit.go (about)

     1  package common
     2  
     3  import (
     4  	"errors"
     5  	"strconv"
     6  	"sync"
     7  	"time"
     8  )
     9  
    10  var ErrBadRateLimiter = errors.New("That rate limiter doesn't exist")
    11  var ErrExceededRateLimit = errors.New("You're exceeding a rate limit. Please wait a while before trying again.")
    12  
    13  // TODO: Persist rate limits to disk
    14  type RateLimiter interface {
    15  	LimitIP(limit, ip string) error
    16  	LimitUser(limit string, user int) error
    17  }
    18  
    19  type RateData struct {
    20  	value     int
    21  	floorTime int
    22  }
    23  
    24  type RateFence struct {
    25  	duration int
    26  	max      int
    27  }
    28  
    29  // TODO: Optimise this by using something other than a string when possible
    30  type RateLimit struct {
    31  	data   map[string][]RateData
    32  	fences []RateFence
    33  
    34  	sync.RWMutex
    35  }
    36  
    37  func NewRateLimit(fences []RateFence) *RateLimit {
    38  	for i, fence := range fences {
    39  		fences[i].duration = fence.duration * 1000 * 1000 * 1000
    40  	}
    41  	return &RateLimit{data: make(map[string][]RateData), fences: fences}
    42  }
    43  
    44  func (l *RateLimit) Limit(name string, ltype int) error {
    45  	l.Lock()
    46  	defer l.Unlock()
    47  
    48  	data, ok := l.data[name]
    49  	if !ok {
    50  		data = make([]RateData, len(l.fences))
    51  		for i, _ := range data {
    52  			data[i] = RateData{0, int(time.Now().Unix())}
    53  		}
    54  	}
    55  
    56  	for i, field := range data {
    57  		fence := l.fences[i]
    58  		diff := int(time.Now().Unix()) - field.floorTime
    59  
    60  		if diff >= fence.duration {
    61  			field = RateData{0, int(time.Now().Unix())}
    62  			data[i] = field
    63  		}
    64  
    65  		if field.value > fence.max {
    66  			return ErrExceededRateLimit
    67  		}
    68  
    69  		field.value++
    70  		data[i] = field
    71  	}
    72  
    73  	return nil
    74  }
    75  
    76  type DefaultRateLimiter struct {
    77  	limits map[string]*RateLimit
    78  }
    79  
    80  func NewDefaultRateLimiter() *DefaultRateLimiter {
    81  	return &DefaultRateLimiter{map[string]*RateLimit{
    82  		"register": NewRateLimit([]RateFence{{int(time.Hour / 2), 1}}),
    83  	}}
    84  }
    85  
    86  func (l *DefaultRateLimiter) LimitIP(limit, ip string) error {
    87  	limiter, ok := l.limits[limit]
    88  	if !ok {
    89  		return ErrBadRateLimiter
    90  	}
    91  	return limiter.Limit(ip, 0)
    92  }
    93  
    94  func (l *DefaultRateLimiter) LimitUser(limit string, user int) error {
    95  	limiter, ok := l.limits[limit]
    96  	if !ok {
    97  		return ErrBadRateLimiter
    98  	}
    99  	return limiter.Limit(strconv.Itoa(user), 1)
   100  }