gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter@v0.0.0-20230411193226-3247984d5abc/memorystore/store.go

// Package memorystore defines an in-memory storage system for rate limiting.
package memorystore

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter"
	"gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter/fasttime"
)

var _ limiter.Store = (*store)(nil)

type store struct {
	tokens   uint64
	interval time.Duration

	sweepInterval time.Duration
	sweepMinTTL   uint64

	data     map[string]*bucket
	dataLock sync.RWMutex

	stopped uint32
	stopCh  chan struct{}
}

// Config is used as input to New. It defines the behavior of the storage
// system.
type Config struct {
	// Tokens is the number of tokens to allow per interval. The default value is
	// 1.
	Tokens uint64

	// Interval is the time interval upon which to enforce rate limiting. The
	// default value is 1 second.
	Interval time.Duration

	// SweepInterval is the rate at which to run the garbage collection on stale
	// entries. Setting this to a low value will optimize memory consumption, but
	// will likely reduce performance and increase lock contention. Setting this
	// to a high value will maximize throughput, but will increase the memory
	// footprint. This can be tuned in combination with SweepMinTTL to control how
	// long stale entries are kept (a sample tuning appears in the example on New
	// below). The default value is 6 hours.
	SweepInterval time.Duration

	// SweepMinTTL is the minimum amount of time a session must be inactive before
	// clearing it from the entries. There's no validation, but this should be at
	// least as high as your rate limit interval, or else the data store will purge
	// records before the limit is applied. The default value is 12 hours.
	SweepMinTTL time.Duration

	// InitialAlloc is the size to use for the in-memory map. Go will
	// automatically expand the buffer, but choosing a higher number can trade
	// memory consumption for performance, as it limits the number of times the
	// map needs to expand. The default value is 4096.
	InitialAlloc int
}

// New creates an in-memory rate limiter that uses a bucketing model to limit
// the number of permitted events over an interval. It's optimized for runtime
// and memory efficiency.
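//
// A minimal usage sketch, assuming the defaults are acceptable except for the
// illustrative values shown (ctx is an existing context.Context, the key is
// arbitrary, and error handling is abbreviated):
//
//	store, err := memorystore.New(&memorystore.Config{
//		Tokens:        10,
//		Interval:      time.Minute,
//		SweepInterval: time.Hour,
//		SweepMinTTL:   2 * time.Hour,
//	})
//	if err != nil {
//		// handle error
//	}
//	defer store.Close(ctx)
//
//	limit, remaining, reset, ok, err := store.Take(ctx, "client-key")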
func New(c *Config) (limiter.Store, error) {
	if c == nil {
		c = new(Config)
	}

	tokens := uint64(1)
	if c.Tokens > 0 {
		tokens = c.Tokens
	}

	interval := 1 * time.Second
	if c.Interval > 0 {
		interval = c.Interval
	}

	sweepInterval := 6 * time.Hour
	if c.SweepInterval > 0 {
		sweepInterval = c.SweepInterval
	}

	sweepMinTTL := 12 * time.Hour
	if c.SweepMinTTL > 0 {
		sweepMinTTL = c.SweepMinTTL
	}

	initialAlloc := 4096
	if c.InitialAlloc > 0 {
		initialAlloc = c.InitialAlloc
	}

	s := &store{
		tokens:   tokens,
		interval: interval,

		sweepInterval: sweepInterval,
		sweepMinTTL:   uint64(sweepMinTTL),

		data:   make(map[string]*bucket, initialAlloc),
		stopCh: make(chan struct{}),
	}
	go s.purge()
	return s, nil
}

// Take attempts to remove a token from the named key. If the take is
// successful, it returns true, otherwise false. It also returns the configured
// limit, remaining tokens, and reset time.
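//
// The reset value is an absolute timestamp expressed in nanoseconds since the
// Unix epoch (see the bucket comments below). As a sketch, a caller with a
// store created by New might interpret the results like this:
//
//	_, _, reset, ok, err := store.Take(ctx, key)
//	if err != nil {
//		// handle error
//	}
//	if !ok {
//		wait := time.Until(time.Unix(0, int64(reset)))
//		// deny the request; a token becomes available after wait
//	}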
func (s *store) Take(ctx context.Context, key string) (uint64, uint64, uint64, bool, error) {
	return s.TakeMany(ctx, key, 1)
}

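// TakeMany attempts to remove takeAmount tokens from the named key in a single
// operation. Like Take, it returns the configured limit, remaining tokens,
// reset time, and whether the take was successful. The take succeeds only when
// at least takeAmount tokens are available; it never removes a partial amount.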
func (s *store) TakeMany(ctx context.Context, key string, takeAmount uint64) (uint64, uint64, uint64, bool, error) {
	// If the store is stopped, all requests are rejected.
	if atomic.LoadUint32(&s.stopped) == 1 {
		return 0, 0, 0, false, limiter.ErrStopped
	}

	// Acquire a read lock first - this allows other callers to concurrently check
	// limits without taking a full lock.
	s.dataLock.RLock()
	if b, ok := s.data[key]; ok {
		s.dataLock.RUnlock()
		return b.take(takeAmount)
	}
	s.dataLock.RUnlock()

	// The key was not in the map, so take out a full lock. We have to check
	// whether the key exists again, because it's possible another goroutine
	// created it between our shared lock and exclusive lock.
	s.dataLock.Lock()
	if b, ok := s.data[key]; ok {
		s.dataLock.Unlock()
		return b.take(takeAmount)
	}

	// This is the first time we've seen this entry (or it's been garbage
	// collected), so create the bucket and take an initial request.
	b := newBucket(s.tokens, s.interval)

	// Add it to the map and take.
	s.data[key] = b
	s.dataLock.Unlock()
	return b.take(takeAmount)
}

// Get retrieves the configured limit and currently remaining tokens for the
// key, if any exists.
func (s *store) Get(ctx context.Context, key string) (uint64, uint64, error) {
	// If the store is stopped, all requests are rejected.
	if atomic.LoadUint32(&s.stopped) == 1 {
		return 0, 0, limiter.ErrStopped
	}

	// Acquire a read lock first - this allows other callers to concurrently check
	// limits without taking a full lock.
	s.dataLock.RLock()
	if b, ok := s.data[key]; ok {
		s.dataLock.RUnlock()
		return b.get()
	}
	s.dataLock.RUnlock()

	return 0, 0, nil
}

// Set configures the bucket-specific tokens and interval.
func (s *store) Set(ctx context.Context, key string, tokens uint64, interval time.Duration) error {
	s.dataLock.Lock()
	b := newBucket(tokens, interval)
	s.data[key] = b
	s.dataLock.Unlock()
	return nil
}

// Burst adds the provided value to the bucket's currently available tokens.
func (s *store) Burst(ctx context.Context, key string, tokens uint64) error {
	s.dataLock.Lock()
	if b, ok := s.data[key]; ok {
		b.lock.Lock()
		s.dataLock.Unlock()
		b.availableTokens = b.availableTokens + tokens
		b.lock.Unlock()
		return nil
	}

	// If we got this far, there's no current record for the key.
	b := newBucket(s.tokens+tokens, s.interval)
	s.data[key] = b
	s.dataLock.Unlock()
	return nil
}

// Close stops the memory limiter and cleans up any outstanding sessions. You
// should always call Close() because it releases the memory consumed by the
// map and stops the sweep ticker.
func (s *store) Close(ctx context.Context) error {
	if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
		return nil
	}

	// Close the channel to prevent future purging.
	close(s.stopCh)

	// Delete all the things.
	s.dataLock.Lock()
	for k := range s.data {
		delete(s.data, k)
	}
	s.dataLock.Unlock()
	return nil
}

// purge continually iterates over the map and purges old values on the provided
// sweep interval. Earlier designs used a goroutine-per-item expiration, but it
// actually generated *more* lock contention under normal use. The most
// performant option with real-world data was a global garbage collection on a
// fixed interval.
func (s *store) purge() {
	ticker := time.NewTicker(s.sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
		}

		s.dataLock.Lock()
		now := fasttime.Now()
		for k, b := range s.data {
			b.lock.Lock()
			lastTime := b.startTime + (b.lastTick * uint64(b.interval))
			b.lock.Unlock()

			if now-lastTime > s.sweepMinTTL {
				delete(s.data, k)
			}
		}
		s.dataLock.Unlock()
	}
}

// bucket is the internal per-key structure that tracks available tokens.
type bucket struct {
	// startTime is the number of nanoseconds from unix epoch when this bucket was
	// initially created.
	startTime uint64

	// maxTokens is the maximum number of tokens permitted on the bucket at any
	// time. The number of available tokens will never exceed this value.
	maxTokens uint64

	// interval is the duration of a single tick. Available tokens refill to
	// maxTokens at the start of each new tick.
	interval time.Duration

	// availableTokens is the current point-in-time number of tokens remaining.
	availableTokens uint64

	// lastTick is the last clock tick, used to re-calculate the number of tokens
	// on the bucket.
	lastTick uint64

	// lock guards the mutable fields.
	lock sync.Mutex
}

// newBucket creates a new bucket from the given tokens and interval.
func newBucket(tokens uint64, interval time.Duration) *bucket {
	b := &bucket{
		startTime:       fasttime.Now(),
		maxTokens:       tokens,
		availableTokens: tokens,
		interval:        interval,
	}
	return b
}

// get returns the bucket's maximum and currently available tokens.
func (b *bucket) get() (tokens uint64, remaining uint64, retErr error) {
	b.lock.Lock()
	defer b.lock.Unlock()

	tokens = b.maxTokens
	remaining = b.availableTokens
	return
}

// take attempts to remove takeAmount tokens from the bucket. If the clock has
// ticked forward since the last take, the available tokens are first refilled
// to the maximum. It returns the limit, remaining tokens, reset time (in
// nanoseconds since the Unix epoch), and whether the take was successful.
func (b *bucket) take(takeAmount uint64) (tokens uint64, remaining uint64, reset uint64, ok bool, retErr error) {
	// Capture the current request time, the current tick, and the absolute time
	// at which the bucket next resets.
	now := fasttime.Now()
	currTick := tick(b.startTime, now, b.interval)

	tokens = b.maxTokens
	reset = b.startTime + ((currTick + 1) * uint64(b.interval))

	b.lock.Lock()
	defer b.lock.Unlock()

	// If we're on a new tick since last assessment, perform
	// a full reset up to maxTokens.
	if b.lastTick < currTick {
		b.availableTokens = b.maxTokens
		b.lastTick = currTick
	}

	if b.availableTokens >= takeAmount {
		b.availableTokens -= takeAmount
		ok = true
		remaining = b.availableTokens
	}
	return
}

// tick is the total number of whole intervals that have elapsed between the
// start time (start) and the current time (curr). For example, if the start
// time was 12:30pm and the interval is 5 minutes, tick returns 6 at 1:00pm
// because six full 5-minute intervals have elapsed. At 12:59pm it would still
// return 5, because the sixth interval has not yet completed.
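//
// In code, mirroring the example above (values are nanosecond timestamps, as
// used throughout this package):
//
//	tick(start, start+30*uint64(time.Minute), 5*time.Minute) // 6
//	tick(start, start+29*uint64(time.Minute), 5*time.Minute) // 5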
func tick(start, curr uint64, interval time.Duration) uint64 {
	return (curr - start) / uint64(interval.Nanoseconds())
}