github.com/diamondburned/arikawa/v2@v2.1.0/api/rate/rate.go (about) 1 package rate 2 3 import ( 4 "context" 5 "net/http" 6 "strconv" 7 "strings" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/diamondburned/arikawa/v2/internal/moreatomic" 13 "github.com/pkg/errors" 14 ) 15 16 // ExtraDelay because Discord is trash. I've seen this in both litcord and 17 // discordgo, with dgo claiming from experiments. 18 // RE: Those who want others to fix it for them: release the source code then. 19 const ExtraDelay = 250 * time.Millisecond 20 21 // ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit 22 // exceeds the deadline of the context.Context. 23 var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline") 24 25 // This makes me suicidal. 26 // https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go 27 28 type Limiter struct { 29 // Only 1 per bucket 30 CustomLimits []*CustomRateLimit 31 32 Prefix string 33 34 // global is a pointer to prevent ARM-compatibility alignment. 35 global *int64 // atomic guarded, unixnano 36 37 bucketMu sync.Mutex 38 buckets map[string]*bucket 39 } 40 41 type CustomRateLimit struct { 42 Contains string 43 Reset time.Duration 44 } 45 46 type bucket struct { 47 lock moreatomic.CtxMutex 48 custom *CustomRateLimit 49 50 remaining uint64 51 52 reset time.Time 53 lastReset time.Time // only for custom 54 } 55 56 func newBucket() *bucket { 57 return &bucket{ 58 lock: *moreatomic.NewCtxMutex(), 59 remaining: 1, 60 } 61 } 62 63 func NewLimiter(prefix string) *Limiter { 64 return &Limiter{ 65 Prefix: prefix, 66 global: new(int64), 67 buckets: map[string]*bucket{}, 68 CustomLimits: []*CustomRateLimit{}, 69 } 70 } 71 72 func (l *Limiter) getBucket(path string, store bool) *bucket { 73 path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix)) 74 75 l.bucketMu.Lock() 76 defer l.bucketMu.Unlock() 77 78 bc, ok := l.buckets[path] 79 if !ok && !store { 80 return nil 81 } 82 83 if !ok { 84 bc := newBucket() 85 86 for _, limit := range l.CustomLimits { 87 if strings.Contains(path, limit.Contains) { 88 bc.custom = limit 89 break 90 } 91 } 92 93 l.buckets[path] = bc 94 return bc 95 } 96 97 return bc 98 } 99 100 // Acquire acquires the rate limiter for the given URL bucket. 101 func (l *Limiter) Acquire(ctx context.Context, path string) error { 102 b := l.getBucket(path, true) 103 104 if err := b.lock.Lock(ctx); err != nil { 105 return err 106 } 107 108 // Deadline until the limiter is released. 109 until := time.Time{} 110 now := time.Now() 111 112 if b.remaining == 0 && b.reset.After(now) { 113 // out of turns, gotta wait 114 until = b.reset 115 } else { 116 // maybe global rate limit has it 117 until = time.Unix(0, atomic.LoadInt64(l.global)) 118 } 119 120 if until.After(now) { 121 if deadline, ok := ctx.Deadline(); ok && until.After(deadline) { 122 return ErrTimedOutEarly 123 } 124 125 select { 126 case <-ctx.Done(): 127 b.lock.Unlock() 128 return ctx.Err() 129 case <-time.After(until.Sub(now)): 130 } 131 } 132 133 if b.remaining > 0 { 134 b.remaining-- 135 } 136 137 return nil 138 } 139 140 // Release releases the URL from the locks. This doesn't need a context for 141 // timing out, since it doesn't block that much. 142 func (l *Limiter) Release(path string, headers http.Header) error { 143 b := l.getBucket(path, false) 144 if b == nil { 145 return nil 146 } 147 148 // TryUnlock because Release may be called when Acquire has not been. 149 defer b.lock.TryUnlock() 150 151 // Check custom limiter 152 if b.custom != nil { 153 now := time.Now() 154 155 if now.Sub(b.lastReset) >= b.custom.Reset { 156 b.lastReset = now 157 b.reset = now.Add(b.custom.Reset) 158 } 159 160 return nil 161 } 162 163 // Check if headers is nil or not: 164 if headers == nil { 165 return nil 166 } 167 168 var ( 169 // boolean 170 global = headers.Get("X-RateLimit-Global") 171 172 // seconds 173 remaining = headers.Get("X-RateLimit-Remaining") 174 reset = headers.Get("X-RateLimit-Reset") // float 175 retryAfter = headers.Get("Retry-After") 176 ) 177 178 switch { 179 case retryAfter != "": 180 i, err := strconv.Atoi(retryAfter) 181 if err != nil { 182 return errors.Wrapf(err, "invalid retryAfter %q", retryAfter) 183 } 184 185 at := time.Now().Add(time.Duration(i) * time.Second) 186 187 if global != "" { // probably "true" 188 atomic.StoreInt64(l.global, at.UnixNano()) 189 } else { 190 b.reset = at 191 } 192 193 case reset != "": 194 unix, err := strconv.ParseFloat(reset, 64) 195 if err != nil { 196 return errors.Wrap(err, "invalid reset "+reset) 197 } 198 199 sec := int64(unix) 200 nsec := int64((unix - float64(sec)) * float64(time.Second)) 201 202 b.reset = time.Unix(sec, nsec).Add(ExtraDelay) 203 } 204 205 if remaining != "" { 206 u, err := strconv.ParseUint(remaining, 10, 64) 207 if err != nil { 208 return errors.Wrap(err, "invalid remaining "+remaining) 209 } 210 211 b.remaining = u 212 } 213 214 return nil 215 }