github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/ratelimit.go (about) 1 package common 2 3 import ( 4 "errors" 5 "strconv" 6 "sync" 7 "time" 8 ) 9 10 var ErrBadRateLimiter = errors.New("That rate limiter doesn't exist") 11 var ErrExceededRateLimit = errors.New("You're exceeding a rate limit. Please wait a while before trying again.") 12 13 // TODO: Persist rate limits to disk 14 type RateLimiter interface { 15 LimitIP(limit, ip string) error 16 LimitUser(limit string, user int) error 17 } 18 19 type RateData struct { 20 value int 21 floorTime int 22 } 23 24 type RateFence struct { 25 duration int 26 max int 27 } 28 29 // TODO: Optimise this by using something other than a string when possible 30 type RateLimit struct { 31 data map[string][]RateData 32 fences []RateFence 33 34 sync.RWMutex 35 } 36 37 func NewRateLimit(fences []RateFence) *RateLimit { 38 for i, fence := range fences { 39 fences[i].duration = fence.duration * 1000 * 1000 * 1000 40 } 41 return &RateLimit{data: make(map[string][]RateData), fences: fences} 42 } 43 44 func (l *RateLimit) Limit(name string, ltype int) error { 45 l.Lock() 46 defer l.Unlock() 47 48 data, ok := l.data[name] 49 if !ok { 50 data = make([]RateData, len(l.fences)) 51 for i, _ := range data { 52 data[i] = RateData{0, int(time.Now().Unix())} 53 } 54 } 55 56 for i, field := range data { 57 fence := l.fences[i] 58 diff := int(time.Now().Unix()) - field.floorTime 59 60 if diff >= fence.duration { 61 field = RateData{0, int(time.Now().Unix())} 62 data[i] = field 63 } 64 65 if field.value > fence.max { 66 return ErrExceededRateLimit 67 } 68 69 field.value++ 70 data[i] = field 71 } 72 73 return nil 74 } 75 76 type DefaultRateLimiter struct { 77 limits map[string]*RateLimit 78 } 79 80 func NewDefaultRateLimiter() *DefaultRateLimiter { 81 return &DefaultRateLimiter{map[string]*RateLimit{ 82 "register": NewRateLimit([]RateFence{{int(time.Hour / 2), 1}}), 83 }} 84 } 85 86 func (l *DefaultRateLimiter) LimitIP(limit, ip string) error { 87 limiter, ok := l.limits[limit] 88 if !ok { 89 return ErrBadRateLimiter 90 } 91 return limiter.Limit(ip, 0) 92 } 93 94 func (l *DefaultRateLimiter) LimitUser(limit string, user int) error { 95 limiter, ok := l.limits[limit] 96 if !ok { 97 return ErrBadRateLimiter 98 } 99 return limiter.Limit(strconv.Itoa(user), 1) 100 }