github.com/diamondburned/arikawa/v2@v2.1.0/api/rate/rate.go (about)

     1  package rate
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/diamondburned/arikawa/v2/internal/moreatomic"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  // ExtraDelay because Discord is trash. I've seen this in both litcord and
    17  // discordgo, with dgo claiming from  experiments.
    18  // RE: Those who want others to fix it for them: release the source code then.
    19  const ExtraDelay = 250 * time.Millisecond
    20  
    21  // ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
    22  // exceeds the deadline of the context.Context.
    23  var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline")
    24  
    25  // This makes me suicidal.
    26  // https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
    27  
    28  type Limiter struct {
    29  	// Only 1 per bucket
    30  	CustomLimits []*CustomRateLimit
    31  
    32  	Prefix string
    33  
    34  	// global is a pointer to prevent ARM-compatibility alignment.
    35  	global *int64 // atomic guarded, unixnano
    36  
    37  	bucketMu sync.Mutex
    38  	buckets  map[string]*bucket
    39  }
    40  
    41  type CustomRateLimit struct {
    42  	Contains string
    43  	Reset    time.Duration
    44  }
    45  
    46  type bucket struct {
    47  	lock   moreatomic.CtxMutex
    48  	custom *CustomRateLimit
    49  
    50  	remaining uint64
    51  
    52  	reset     time.Time
    53  	lastReset time.Time // only for custom
    54  }
    55  
    56  func newBucket() *bucket {
    57  	return &bucket{
    58  		lock:      *moreatomic.NewCtxMutex(),
    59  		remaining: 1,
    60  	}
    61  }
    62  
    63  func NewLimiter(prefix string) *Limiter {
    64  	return &Limiter{
    65  		Prefix:       prefix,
    66  		global:       new(int64),
    67  		buckets:      map[string]*bucket{},
    68  		CustomLimits: []*CustomRateLimit{},
    69  	}
    70  }
    71  
    72  func (l *Limiter) getBucket(path string, store bool) *bucket {
    73  	path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
    74  
    75  	l.bucketMu.Lock()
    76  	defer l.bucketMu.Unlock()
    77  
    78  	bc, ok := l.buckets[path]
    79  	if !ok && !store {
    80  		return nil
    81  	}
    82  
    83  	if !ok {
    84  		bc := newBucket()
    85  
    86  		for _, limit := range l.CustomLimits {
    87  			if strings.Contains(path, limit.Contains) {
    88  				bc.custom = limit
    89  				break
    90  			}
    91  		}
    92  
    93  		l.buckets[path] = bc
    94  		return bc
    95  	}
    96  
    97  	return bc
    98  }
    99  
   100  // Acquire acquires the rate limiter for the given URL bucket.
   101  func (l *Limiter) Acquire(ctx context.Context, path string) error {
   102  	b := l.getBucket(path, true)
   103  
   104  	if err := b.lock.Lock(ctx); err != nil {
   105  		return err
   106  	}
   107  
   108  	// Deadline until the limiter is released.
   109  	until := time.Time{}
   110  	now := time.Now()
   111  
   112  	if b.remaining == 0 && b.reset.After(now) {
   113  		// out of turns, gotta wait
   114  		until = b.reset
   115  	} else {
   116  		// maybe global rate limit has it
   117  		until = time.Unix(0, atomic.LoadInt64(l.global))
   118  	}
   119  
   120  	if until.After(now) {
   121  		if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
   122  			return ErrTimedOutEarly
   123  		}
   124  
   125  		select {
   126  		case <-ctx.Done():
   127  			b.lock.Unlock()
   128  			return ctx.Err()
   129  		case <-time.After(until.Sub(now)):
   130  		}
   131  	}
   132  
   133  	if b.remaining > 0 {
   134  		b.remaining--
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  // Release releases the URL from the locks. This doesn't need a context for
   141  // timing out, since it doesn't block that much.
   142  func (l *Limiter) Release(path string, headers http.Header) error {
   143  	b := l.getBucket(path, false)
   144  	if b == nil {
   145  		return nil
   146  	}
   147  
   148  	// TryUnlock because Release may be called when Acquire has not been.
   149  	defer b.lock.TryUnlock()
   150  
   151  	// Check custom limiter
   152  	if b.custom != nil {
   153  		now := time.Now()
   154  
   155  		if now.Sub(b.lastReset) >= b.custom.Reset {
   156  			b.lastReset = now
   157  			b.reset = now.Add(b.custom.Reset)
   158  		}
   159  
   160  		return nil
   161  	}
   162  
   163  	// Check if headers is nil or not:
   164  	if headers == nil {
   165  		return nil
   166  	}
   167  
   168  	var (
   169  		// boolean
   170  		global = headers.Get("X-RateLimit-Global")
   171  
   172  		// seconds
   173  		remaining  = headers.Get("X-RateLimit-Remaining")
   174  		reset      = headers.Get("X-RateLimit-Reset") // float
   175  		retryAfter = headers.Get("Retry-After")
   176  	)
   177  
   178  	switch {
   179  	case retryAfter != "":
   180  		i, err := strconv.Atoi(retryAfter)
   181  		if err != nil {
   182  			return errors.Wrapf(err, "invalid retryAfter %q", retryAfter)
   183  		}
   184  
   185  		at := time.Now().Add(time.Duration(i) * time.Second)
   186  
   187  		if global != "" { // probably "true"
   188  			atomic.StoreInt64(l.global, at.UnixNano())
   189  		} else {
   190  			b.reset = at
   191  		}
   192  
   193  	case reset != "":
   194  		unix, err := strconv.ParseFloat(reset, 64)
   195  		if err != nil {
   196  			return errors.Wrap(err, "invalid reset "+reset)
   197  		}
   198  
   199  		sec := int64(unix)
   200  		nsec := int64((unix - float64(sec)) * float64(time.Second))
   201  
   202  		b.reset = time.Unix(sec, nsec).Add(ExtraDelay)
   203  	}
   204  
   205  	if remaining != "" {
   206  		u, err := strconv.ParseUint(remaining, 10, 64)
   207  		if err != nil {
   208  			return errors.Wrap(err, "invalid remaining "+remaining)
   209  		}
   210  
   211  		b.remaining = u
   212  	}
   213  
   214  	return nil
   215  }