github.com/diamondburned/arikawa@v1.3.14/api/rate/rate.go (about)

     1  package rate
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/pkg/errors"
    13  
    14  	"github.com/diamondburned/arikawa/internal/moreatomic"
    15  )
    16  
    17  // ExtraDelay because Discord is trash. I've seen this in both litcord and
    18  // discordgo, with dgo claiming from his experiments.
    19  // RE: Those who want others to fix it for them: release the source code then.
    20  const ExtraDelay = 250 * time.Millisecond
    21  
// The rate limiting approach below is adapted from discordgo's limiter:
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
    24  
    25  type Limiter struct {
    26  	// Only 1 per bucket
    27  	CustomLimits []*CustomRateLimit
    28  
    29  	Prefix string
    30  
    31  	global  *int64 // atomic guarded, unixnano
    32  	buckets sync.Map
    33  }
    34  
    35  type CustomRateLimit struct {
    36  	Contains string
    37  	Reset    time.Duration
    38  }
    39  
    40  type bucket struct {
    41  	lock   moreatomic.CtxMutex
    42  	custom *CustomRateLimit
    43  
    44  	remaining uint64
    45  
    46  	reset     time.Time
    47  	lastReset time.Time // only for custom
    48  }
    49  
    50  func newBucket() *bucket {
    51  	return &bucket{
    52  		lock:      *moreatomic.NewCtxMutex(),
    53  		remaining: 1,
    54  	}
    55  }
    56  
    57  func NewLimiter(prefix string) *Limiter {
    58  	return &Limiter{
    59  		Prefix:       prefix,
    60  		global:       new(int64),
    61  		buckets:      sync.Map{},
    62  		CustomLimits: []*CustomRateLimit{},
    63  	}
    64  }
    65  
    66  func (l *Limiter) getBucket(path string, store bool) *bucket {
    67  	path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
    68  
    69  	bc, ok := l.buckets.Load(path)
    70  	if !ok && !store {
    71  		return nil
    72  	}
    73  
    74  	if !ok {
    75  		bc := newBucket()
    76  
    77  		for _, limit := range l.CustomLimits {
    78  			if strings.Contains(path, limit.Contains) {
    79  				bc.custom = limit
    80  				break
    81  			}
    82  		}
    83  
    84  		l.buckets.Store(path, bc)
    85  		return bc
    86  	}
    87  
    88  	return bc.(*bucket)
    89  }
    90  
    91  func (l *Limiter) Acquire(ctx context.Context, path string) error {
    92  	b := l.getBucket(path, true)
    93  
    94  	if err := b.lock.Lock(ctx); err != nil {
    95  		return err
    96  	}
    97  
    98  	// Time to sleep
    99  	var sleep time.Duration
   100  
   101  	if b.remaining == 0 && b.reset.After(time.Now()) {
   102  		// out of turns, gotta wait
   103  		sleep = time.Until(b.reset)
   104  	} else {
   105  		// maybe global rate limit has it
   106  		now := time.Now()
   107  		until := time.Unix(0, atomic.LoadInt64(l.global))
   108  
   109  		if until.After(now) {
   110  			sleep = until.Sub(now)
   111  		}
   112  	}
   113  
   114  	if sleep > 0 {
   115  		select {
   116  		case <-ctx.Done():
   117  			b.lock.Unlock()
   118  			return ctx.Err()
   119  		case <-time.After(sleep):
   120  		}
   121  	}
   122  
   123  	if b.remaining > 0 {
   124  		b.remaining--
   125  	}
   126  
   127  	return nil
   128  }
   129  
   130  // Release releases the URL from the locks. This doesn't need a context for
   131  // timing out, since it doesn't block that much.
   132  func (l *Limiter) Release(path string, headers http.Header) error {
   133  	b := l.getBucket(path, false)
   134  	if b == nil {
   135  		return nil
   136  	}
   137  
   138  	// TryUnlock because Release may be called when Acquire has not been.
   139  	defer b.lock.TryUnlock()
   140  
   141  	// Check custom limiter
   142  	if b.custom != nil {
   143  		now := time.Now()
   144  
   145  		if now.Sub(b.lastReset) >= b.custom.Reset {
   146  			b.lastReset = now
   147  			b.reset = now.Add(b.custom.Reset)
   148  		}
   149  
   150  		return nil
   151  	}
   152  
   153  	// Check if headers is nil or not:
   154  	if headers == nil {
   155  		return nil
   156  	}
   157  
   158  	var (
   159  		// boolean
   160  		global = headers.Get("X-RateLimit-Global")
   161  
   162  		// seconds
   163  		remaining  = headers.Get("X-RateLimit-Remaining")
   164  		reset      = headers.Get("X-RateLimit-Reset")
   165  		retryAfter = headers.Get("Retry-After")
   166  	)
   167  
   168  	switch {
   169  	case retryAfter != "":
   170  		i, err := strconv.Atoi(retryAfter)
   171  		if err != nil {
   172  			return errors.Wrap(err, "invalid retryAfter "+retryAfter)
   173  		}
   174  
   175  		at := time.Now().Add(time.Duration(i) * time.Millisecond)
   176  
   177  		if global != "" { // probably true
   178  			atomic.StoreInt64(l.global, at.UnixNano())
   179  		} else {
   180  			b.reset = at
   181  		}
   182  
   183  	case reset != "":
   184  		unix, err := strconv.ParseFloat(reset, 64)
   185  		if err != nil {
   186  			return errors.Wrap(err, "invalid reset "+reset)
   187  		}
   188  
   189  		b.reset = time.Unix(0, int64(unix*float64(time.Second))).
   190  			Add(ExtraDelay)
   191  	}
   192  
   193  	if remaining != "" {
   194  		u, err := strconv.ParseUint(remaining, 10, 64)
   195  		if err != nil {
   196  			return errors.Wrap(err, "invalid remaining "+remaining)
   197  		}
   198  
   199  		b.remaining = u
   200  	}
   201  
   202  	return nil
   203  }