github.com/sethvargo/go-limiter@v1.0.0/memorystore/store.go (about)

     1  // Package memorystore defines an in-memory storage system for limiting.
     2  package memorystore
     3  
     4  import (
     5  	"context"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/sethvargo/go-limiter"
    11  	"github.com/sethvargo/go-limiter/internal/fasttime"
    12  )
    13  
    14  var _ limiter.Store = (*store)(nil)
    15  
    16  type store struct {
    17  	tokens   uint64
    18  	interval time.Duration
    19  
    20  	sweepInterval time.Duration
    21  	sweepMinTTL   uint64
    22  
    23  	data     map[string]*bucket
    24  	dataLock sync.RWMutex
    25  
    26  	stopped uint32
    27  	stopCh  chan struct{}
    28  }
    29  
    30  // Config is used as input to New. It defines the behavior of the storage
    31  // system.
    32  type Config struct {
    33  	// Tokens is the number of tokens to allow per interval. The default value is
    34  	// 1.
    35  	Tokens uint64
    36  
    37  	// Interval is the time interval upon which to enforce rate limiting. The
    38  	// default value is 1 second.
    39  	Interval time.Duration
    40  
    41  	// SweepInterval is the rate at which to run the garabage collection on stale
    42  	// entries. Setting this to a low value will optimize memory consumption, but
    43  	// will likely reduce performance and increase lock contention. Setting this
    44  	// to a high value will maximum throughput, but will increase the memory
    45  	// footprint. This can be tuned in combination with SweepMinTTL to control how
    46  	// long stale entires are kept. The default value is 6 hours.
    47  	SweepInterval time.Duration
    48  
    49  	// SweepMinTTL is the minimum amount of time a session must be inactive before
    50  	// clearing it from the entries. There's no validation, but this should be at
    51  	// least as high as your rate limit, or else the data store will purge records
    52  	// before they limit is applied. The default value is 12 hours.
    53  	SweepMinTTL time.Duration
    54  
    55  	// InitialAlloc is the size to use for the in-memory map. Go will
    56  	// automatically expand the buffer, but choosing higher number can trade
    57  	// memory consumption for performance as it limits the number of times the map
    58  	// needs to expand. The default value is 4096.
    59  	InitialAlloc int
    60  }
    61  
    62  // New creates an in-memory rate limiter that uses a bucketing model to limit
    63  // the number of permitted events over an interval. It's optimized for runtime
    64  // and memory efficiency.
    65  func New(c *Config) (limiter.Store, error) {
    66  	if c == nil {
    67  		c = new(Config)
    68  	}
    69  
    70  	tokens := uint64(1)
    71  	if c.Tokens > 0 {
    72  		tokens = c.Tokens
    73  	}
    74  
    75  	interval := 1 * time.Second
    76  	if c.Interval > 0 {
    77  		interval = c.Interval
    78  	}
    79  
    80  	sweepInterval := 6 * time.Hour
    81  	if c.SweepInterval > 0 {
    82  		sweepInterval = c.SweepInterval
    83  	}
    84  
    85  	sweepMinTTL := 12 * time.Hour
    86  	if c.SweepMinTTL > 0 {
    87  		sweepMinTTL = c.SweepMinTTL
    88  	}
    89  
    90  	initialAlloc := 4096
    91  	if c.InitialAlloc > 0 {
    92  		initialAlloc = c.InitialAlloc
    93  	}
    94  
    95  	s := &store{
    96  		tokens:   tokens,
    97  		interval: interval,
    98  
    99  		sweepInterval: sweepInterval,
   100  		sweepMinTTL:   uint64(sweepMinTTL),
   101  
   102  		data:   make(map[string]*bucket, initialAlloc),
   103  		stopCh: make(chan struct{}),
   104  	}
   105  	go s.purge()
   106  	return s, nil
   107  }
   108  
   109  // Take attempts to remove a token from the named key. If the take is
   110  // successful, it returns true, otherwise false. It also returns the configured
   111  // limit, remaining tokens, and reset time.
   112  func (s *store) Take(ctx context.Context, key string) (uint64, uint64, uint64, bool, error) {
   113  	// If the store is stopped, all requests are rejected.
   114  	if atomic.LoadUint32(&s.stopped) == 1 {
   115  		return 0, 0, 0, false, limiter.ErrStopped
   116  	}
   117  
   118  	// Acquire a read lock first - this allows other to concurrently check limits
   119  	// without taking a full lock.
   120  	s.dataLock.RLock()
   121  	if b, ok := s.data[key]; ok {
   122  		s.dataLock.RUnlock()
   123  		return b.take()
   124  	}
   125  	s.dataLock.RUnlock()
   126  
   127  	// Unfortunately we did not find the key in the map. Take out a full lock. We
   128  	// have to check if the key exists again, because it's possible another
   129  	// goroutine created it between our shared lock and exclusive lock.
   130  	s.dataLock.Lock()
   131  	if b, ok := s.data[key]; ok {
   132  		s.dataLock.Unlock()
   133  		return b.take()
   134  	}
   135  
   136  	// This is the first time we've seen this entry (or it's been garbage
   137  	// collected), so create the bucket and take an initial request.
   138  	b := newBucket(s.tokens, s.interval)
   139  
   140  	// Add it to the map and take.
   141  	s.data[key] = b
   142  	s.dataLock.Unlock()
   143  	return b.take()
   144  }
   145  
   146  // Get retrieves the information about the key, if any exists.
   147  func (s *store) Get(ctx context.Context, key string) (uint64, uint64, error) {
   148  	// If the store is stopped, all requests are rejected.
   149  	if atomic.LoadUint32(&s.stopped) == 1 {
   150  		return 0, 0, limiter.ErrStopped
   151  	}
   152  
   153  	// Acquire a read lock first - this allows other to concurrently check limits
   154  	// without taking a full lock.
   155  	s.dataLock.RLock()
   156  	if b, ok := s.data[key]; ok {
   157  		s.dataLock.RUnlock()
   158  		return b.get()
   159  	}
   160  	s.dataLock.RUnlock()
   161  
   162  	return 0, 0, nil
   163  }
   164  
   165  // Set configures the bucket-specific tokens and interval.
   166  func (s *store) Set(ctx context.Context, key string, tokens uint64, interval time.Duration) error {
   167  	s.dataLock.Lock()
   168  	b := newBucket(tokens, interval)
   169  	s.data[key] = b
   170  	s.dataLock.Unlock()
   171  	return nil
   172  }
   173  
   174  // Burst adds the provided value to the bucket's currently available tokens.
   175  func (s *store) Burst(ctx context.Context, key string, tokens uint64) error {
   176  	s.dataLock.Lock()
   177  	if b, ok := s.data[key]; ok {
   178  		b.lock.Lock()
   179  		s.dataLock.Unlock()
   180  		b.availableTokens = b.availableTokens + tokens
   181  		b.lock.Unlock()
   182  		return nil
   183  	}
   184  
   185  	// If we got this far, there's no current record for the key.
   186  	b := newBucket(s.tokens+tokens, s.interval)
   187  	s.data[key] = b
   188  	s.dataLock.Unlock()
   189  	return nil
   190  }
   191  
   192  // Close stops the memory limiter and cleans up any outstanding
   193  // sessions. You should always call Close() as it releases the memory consumed
   194  // by the map AND releases the tickers.
   195  func (s *store) Close(ctx context.Context) error {
   196  	if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
   197  		return nil
   198  	}
   199  
   200  	// Close the channel to prevent future purging.
   201  	close(s.stopCh)
   202  
   203  	// Delete all the things.
   204  	s.dataLock.Lock()
   205  	for k := range s.data {
   206  		delete(s.data, k)
   207  	}
   208  	s.dataLock.Unlock()
   209  	return nil
   210  }
   211  
   212  // purge continually iterates over the map and purges old values on the provided
   213  // sweep interval. Earlier designs used a go-function-per-item expiration, but
   214  // it actually generated *more* lock contention under normal use. The most
   215  // performant option with real-world data was a global garbage collection on a
   216  // fixed interval.
   217  func (s *store) purge() {
   218  	ticker := time.NewTicker(s.sweepInterval)
   219  	defer ticker.Stop()
   220  
   221  	for {
   222  		select {
   223  		case <-s.stopCh:
   224  			return
   225  		case <-ticker.C:
   226  		}
   227  
   228  		s.dataLock.Lock()
   229  		now := fasttime.Now()
   230  		for k, b := range s.data {
   231  			b.lock.Lock()
   232  			lastTime := b.startTime + (b.lastTick * uint64(b.interval))
   233  			b.lock.Unlock()
   234  
   235  			if now-lastTime > s.sweepMinTTL {
   236  				delete(s.data, k)
   237  			}
   238  		}
   239  		s.dataLock.Unlock()
   240  	}
   241  }
   242  
   243  // bucket is an internal wrapper around a taker.
   244  type bucket struct {
   245  	// startTime is the number of nanoseconds from unix epoch when this bucket was
   246  	// initially created.
   247  	startTime uint64
   248  
   249  	// maxTokens is the maximum number of tokens permitted on the bucket at any
   250  	// time. The number of available tokens will never exceed this value.
   251  	maxTokens uint64
   252  
   253  	// interval is the time at which ticking should occur.
   254  	interval time.Duration
   255  
   256  	// availableTokens is the current point-in-time number of tokens remaining.
   257  	availableTokens uint64
   258  
   259  	// lastTick is the last clock tick, used to re-calculate the number of tokens
   260  	// on the bucket.
   261  	lastTick uint64
   262  
   263  	// lock guards the mutable fields.
   264  	lock sync.Mutex
   265  }
   266  
   267  // newBucket creates a new bucket from the given tokens and interval.
   268  func newBucket(tokens uint64, interval time.Duration) *bucket {
   269  	b := &bucket{
   270  		startTime:       fasttime.Now(),
   271  		maxTokens:       tokens,
   272  		availableTokens: tokens,
   273  		interval:        interval,
   274  	}
   275  	return b
   276  }
   277  
   278  // get returns information about the bucket.
   279  func (b *bucket) get() (tokens uint64, remaining uint64, retErr error) {
   280  	b.lock.Lock()
   281  	defer b.lock.Unlock()
   282  
   283  	tokens = b.maxTokens
   284  	remaining = b.availableTokens
   285  	return
   286  }
   287  
   288  // take attempts to remove a token from the bucket. If there are no tokens
   289  // available and the clock has ticked forward, it recalculates the number of
   290  // tokens and retries. It returns the limit, remaining tokens, time until
   291  // refresh, and whether the take was successful.
   292  func (b *bucket) take() (tokens uint64, remaining uint64, reset uint64, ok bool, retErr error) {
   293  	// Capture the current request time, current tick, and amount of time until
   294  	// the bucket resets.
   295  	now := fasttime.Now()
   296  
   297  	b.lock.Lock()
   298  	defer b.lock.Unlock()
   299  
   300  	// If the current time is before the start time, it means the server clock was
   301  	// reset to an earlier time. In that case, rebase to 0.
   302  	if now < b.startTime {
   303  		b.startTime = now
   304  		b.lastTick = 0
   305  	}
   306  
   307  	currTick := tick(b.startTime, now, b.interval)
   308  
   309  	tokens = b.maxTokens
   310  	reset = b.startTime + ((currTick + 1) * uint64(b.interval))
   311  
   312  	// If we're on a new tick since last assessment, perform
   313  	// a full reset up to maxTokens.
   314  	if b.lastTick < currTick {
   315  		b.availableTokens = b.maxTokens
   316  		b.lastTick = currTick
   317  	}
   318  
   319  	if b.availableTokens > 0 {
   320  		b.availableTokens--
   321  		ok = true
   322  		remaining = b.availableTokens
   323  	}
   324  
   325  	return
   326  }
   327  
   328  // tick is the total number of times the current interval has occurred between
   329  // when the time started (start) and the current time (curr). For example, if
   330  // the start time was 12:30pm and it's currently 1:00pm, and the interval was 5
   331  // minutes, tick would return 6 because 1:00pm is the 6th 5-minute tick. Note
   332  // that tick would return 5 at 12:59pm, because it hasn't reached the 6th tick
   333  // yet.
   334  func tick(start, curr uint64, interval time.Duration) uint64 {
   335  	return (curr - start) / uint64(interval.Nanoseconds())
   336  }