github.com/diamondburned/arikawa@v1.3.14/api/rate/rate.go (about) 1 package rate 2 3 import ( 4 "context" 5 "net/http" 6 "strconv" 7 "strings" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/pkg/errors" 13 14 "github.com/diamondburned/arikawa/internal/moreatomic" 15 ) 16 17 // ExtraDelay because Discord is trash. I've seen this in both litcord and 18 // discordgo, with dgo claiming from his experiments. 19 // RE: Those who want others to fix it for them: release the source code then. 20 const ExtraDelay = 250 * time.Millisecond 21 22 // This makes me suicidal. 23 // https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go 24 25 type Limiter struct { 26 // Only 1 per bucket 27 CustomLimits []*CustomRateLimit 28 29 Prefix string 30 31 global *int64 // atomic guarded, unixnano 32 buckets sync.Map 33 } 34 35 type CustomRateLimit struct { 36 Contains string 37 Reset time.Duration 38 } 39 40 type bucket struct { 41 lock moreatomic.CtxMutex 42 custom *CustomRateLimit 43 44 remaining uint64 45 46 reset time.Time 47 lastReset time.Time // only for custom 48 } 49 50 func newBucket() *bucket { 51 return &bucket{ 52 lock: *moreatomic.NewCtxMutex(), 53 remaining: 1, 54 } 55 } 56 57 func NewLimiter(prefix string) *Limiter { 58 return &Limiter{ 59 Prefix: prefix, 60 global: new(int64), 61 buckets: sync.Map{}, 62 CustomLimits: []*CustomRateLimit{}, 63 } 64 } 65 66 func (l *Limiter) getBucket(path string, store bool) *bucket { 67 path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix)) 68 69 bc, ok := l.buckets.Load(path) 70 if !ok && !store { 71 return nil 72 } 73 74 if !ok { 75 bc := newBucket() 76 77 for _, limit := range l.CustomLimits { 78 if strings.Contains(path, limit.Contains) { 79 bc.custom = limit 80 break 81 } 82 } 83 84 l.buckets.Store(path, bc) 85 return bc 86 } 87 88 return bc.(*bucket) 89 } 90 91 func (l *Limiter) Acquire(ctx context.Context, path string) error { 92 b := l.getBucket(path, true) 93 94 if err := b.lock.Lock(ctx); err != nil { 95 return err 96 } 97 98 // Time to sleep 99 var sleep time.Duration 100 101 if b.remaining == 0 && b.reset.After(time.Now()) { 102 // out of turns, gotta wait 103 sleep = time.Until(b.reset) 104 } else { 105 // maybe global rate limit has it 106 now := time.Now() 107 until := time.Unix(0, atomic.LoadInt64(l.global)) 108 109 if until.After(now) { 110 sleep = until.Sub(now) 111 } 112 } 113 114 if sleep > 0 { 115 select { 116 case <-ctx.Done(): 117 b.lock.Unlock() 118 return ctx.Err() 119 case <-time.After(sleep): 120 } 121 } 122 123 if b.remaining > 0 { 124 b.remaining-- 125 } 126 127 return nil 128 } 129 130 // Release releases the URL from the locks. This doesn't need a context for 131 // timing out, since it doesn't block that much. 132 func (l *Limiter) Release(path string, headers http.Header) error { 133 b := l.getBucket(path, false) 134 if b == nil { 135 return nil 136 } 137 138 // TryUnlock because Release may be called when Acquire has not been. 139 defer b.lock.TryUnlock() 140 141 // Check custom limiter 142 if b.custom != nil { 143 now := time.Now() 144 145 if now.Sub(b.lastReset) >= b.custom.Reset { 146 b.lastReset = now 147 b.reset = now.Add(b.custom.Reset) 148 } 149 150 return nil 151 } 152 153 // Check if headers is nil or not: 154 if headers == nil { 155 return nil 156 } 157 158 var ( 159 // boolean 160 global = headers.Get("X-RateLimit-Global") 161 162 // seconds 163 remaining = headers.Get("X-RateLimit-Remaining") 164 reset = headers.Get("X-RateLimit-Reset") 165 retryAfter = headers.Get("Retry-After") 166 ) 167 168 switch { 169 case retryAfter != "": 170 i, err := strconv.Atoi(retryAfter) 171 if err != nil { 172 return errors.Wrap(err, "invalid retryAfter "+retryAfter) 173 } 174 175 at := time.Now().Add(time.Duration(i) * time.Millisecond) 176 177 if global != "" { // probably true 178 atomic.StoreInt64(l.global, at.UnixNano()) 179 } else { 180 b.reset = at 181 } 182 183 case reset != "": 184 unix, err := strconv.ParseFloat(reset, 64) 185 if err != nil { 186 return errors.Wrap(err, "invalid reset "+reset) 187 } 188 189 b.reset = time.Unix(0, int64(unix*float64(time.Second))). 190 Add(ExtraDelay) 191 } 192 193 if remaining != "" { 194 u, err := strconv.ParseUint(remaining, 10, 64) 195 if err != nil { 196 return errors.Wrap(err, "invalid remaining "+remaining) 197 } 198 199 b.remaining = u 200 } 201 202 return nil 203 }