github.com/sethvargo/go-limiter@v1.0.0/memorystore/store.go (about) 1 // Package memorystore defines an in-memory storage system for limiting. 2 package memorystore 3 4 import ( 5 "context" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/sethvargo/go-limiter" 11 "github.com/sethvargo/go-limiter/internal/fasttime" 12 ) 13 14 var _ limiter.Store = (*store)(nil) 15 16 type store struct { 17 tokens uint64 18 interval time.Duration 19 20 sweepInterval time.Duration 21 sweepMinTTL uint64 22 23 data map[string]*bucket 24 dataLock sync.RWMutex 25 26 stopped uint32 27 stopCh chan struct{} 28 } 29 30 // Config is used as input to New. It defines the behavior of the storage 31 // system. 32 type Config struct { 33 // Tokens is the number of tokens to allow per interval. The default value is 34 // 1. 35 Tokens uint64 36 37 // Interval is the time interval upon which to enforce rate limiting. The 38 // default value is 1 second. 39 Interval time.Duration 40 41 // SweepInterval is the rate at which to run the garabage collection on stale 42 // entries. Setting this to a low value will optimize memory consumption, but 43 // will likely reduce performance and increase lock contention. Setting this 44 // to a high value will maximum throughput, but will increase the memory 45 // footprint. This can be tuned in combination with SweepMinTTL to control how 46 // long stale entires are kept. The default value is 6 hours. 47 SweepInterval time.Duration 48 49 // SweepMinTTL is the minimum amount of time a session must be inactive before 50 // clearing it from the entries. There's no validation, but this should be at 51 // least as high as your rate limit, or else the data store will purge records 52 // before they limit is applied. The default value is 12 hours. 53 SweepMinTTL time.Duration 54 55 // InitialAlloc is the size to use for the in-memory map. Go will 56 // automatically expand the buffer, but choosing higher number can trade 57 // memory consumption for performance as it limits the number of times the map 58 // needs to expand. The default value is 4096. 59 InitialAlloc int 60 } 61 62 // New creates an in-memory rate limiter that uses a bucketing model to limit 63 // the number of permitted events over an interval. It's optimized for runtime 64 // and memory efficiency. 65 func New(c *Config) (limiter.Store, error) { 66 if c == nil { 67 c = new(Config) 68 } 69 70 tokens := uint64(1) 71 if c.Tokens > 0 { 72 tokens = c.Tokens 73 } 74 75 interval := 1 * time.Second 76 if c.Interval > 0 { 77 interval = c.Interval 78 } 79 80 sweepInterval := 6 * time.Hour 81 if c.SweepInterval > 0 { 82 sweepInterval = c.SweepInterval 83 } 84 85 sweepMinTTL := 12 * time.Hour 86 if c.SweepMinTTL > 0 { 87 sweepMinTTL = c.SweepMinTTL 88 } 89 90 initialAlloc := 4096 91 if c.InitialAlloc > 0 { 92 initialAlloc = c.InitialAlloc 93 } 94 95 s := &store{ 96 tokens: tokens, 97 interval: interval, 98 99 sweepInterval: sweepInterval, 100 sweepMinTTL: uint64(sweepMinTTL), 101 102 data: make(map[string]*bucket, initialAlloc), 103 stopCh: make(chan struct{}), 104 } 105 go s.purge() 106 return s, nil 107 } 108 109 // Take attempts to remove a token from the named key. If the take is 110 // successful, it returns true, otherwise false. It also returns the configured 111 // limit, remaining tokens, and reset time. 112 func (s *store) Take(ctx context.Context, key string) (uint64, uint64, uint64, bool, error) { 113 // If the store is stopped, all requests are rejected. 114 if atomic.LoadUint32(&s.stopped) == 1 { 115 return 0, 0, 0, false, limiter.ErrStopped 116 } 117 118 // Acquire a read lock first - this allows other to concurrently check limits 119 // without taking a full lock. 120 s.dataLock.RLock() 121 if b, ok := s.data[key]; ok { 122 s.dataLock.RUnlock() 123 return b.take() 124 } 125 s.dataLock.RUnlock() 126 127 // Unfortunately we did not find the key in the map. Take out a full lock. We 128 // have to check if the key exists again, because it's possible another 129 // goroutine created it between our shared lock and exclusive lock. 130 s.dataLock.Lock() 131 if b, ok := s.data[key]; ok { 132 s.dataLock.Unlock() 133 return b.take() 134 } 135 136 // This is the first time we've seen this entry (or it's been garbage 137 // collected), so create the bucket and take an initial request. 138 b := newBucket(s.tokens, s.interval) 139 140 // Add it to the map and take. 141 s.data[key] = b 142 s.dataLock.Unlock() 143 return b.take() 144 } 145 146 // Get retrieves the information about the key, if any exists. 147 func (s *store) Get(ctx context.Context, key string) (uint64, uint64, error) { 148 // If the store is stopped, all requests are rejected. 149 if atomic.LoadUint32(&s.stopped) == 1 { 150 return 0, 0, limiter.ErrStopped 151 } 152 153 // Acquire a read lock first - this allows other to concurrently check limits 154 // without taking a full lock. 155 s.dataLock.RLock() 156 if b, ok := s.data[key]; ok { 157 s.dataLock.RUnlock() 158 return b.get() 159 } 160 s.dataLock.RUnlock() 161 162 return 0, 0, nil 163 } 164 165 // Set configures the bucket-specific tokens and interval. 166 func (s *store) Set(ctx context.Context, key string, tokens uint64, interval time.Duration) error { 167 s.dataLock.Lock() 168 b := newBucket(tokens, interval) 169 s.data[key] = b 170 s.dataLock.Unlock() 171 return nil 172 } 173 174 // Burst adds the provided value to the bucket's currently available tokens. 175 func (s *store) Burst(ctx context.Context, key string, tokens uint64) error { 176 s.dataLock.Lock() 177 if b, ok := s.data[key]; ok { 178 b.lock.Lock() 179 s.dataLock.Unlock() 180 b.availableTokens = b.availableTokens + tokens 181 b.lock.Unlock() 182 return nil 183 } 184 185 // If we got this far, there's no current record for the key. 186 b := newBucket(s.tokens+tokens, s.interval) 187 s.data[key] = b 188 s.dataLock.Unlock() 189 return nil 190 } 191 192 // Close stops the memory limiter and cleans up any outstanding 193 // sessions. You should always call Close() as it releases the memory consumed 194 // by the map AND releases the tickers. 195 func (s *store) Close(ctx context.Context) error { 196 if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { 197 return nil 198 } 199 200 // Close the channel to prevent future purging. 201 close(s.stopCh) 202 203 // Delete all the things. 204 s.dataLock.Lock() 205 for k := range s.data { 206 delete(s.data, k) 207 } 208 s.dataLock.Unlock() 209 return nil 210 } 211 212 // purge continually iterates over the map and purges old values on the provided 213 // sweep interval. Earlier designs used a go-function-per-item expiration, but 214 // it actually generated *more* lock contention under normal use. The most 215 // performant option with real-world data was a global garbage collection on a 216 // fixed interval. 217 func (s *store) purge() { 218 ticker := time.NewTicker(s.sweepInterval) 219 defer ticker.Stop() 220 221 for { 222 select { 223 case <-s.stopCh: 224 return 225 case <-ticker.C: 226 } 227 228 s.dataLock.Lock() 229 now := fasttime.Now() 230 for k, b := range s.data { 231 b.lock.Lock() 232 lastTime := b.startTime + (b.lastTick * uint64(b.interval)) 233 b.lock.Unlock() 234 235 if now-lastTime > s.sweepMinTTL { 236 delete(s.data, k) 237 } 238 } 239 s.dataLock.Unlock() 240 } 241 } 242 243 // bucket is an internal wrapper around a taker. 244 type bucket struct { 245 // startTime is the number of nanoseconds from unix epoch when this bucket was 246 // initially created. 247 startTime uint64 248 249 // maxTokens is the maximum number of tokens permitted on the bucket at any 250 // time. The number of available tokens will never exceed this value. 251 maxTokens uint64 252 253 // interval is the time at which ticking should occur. 254 interval time.Duration 255 256 // availableTokens is the current point-in-time number of tokens remaining. 257 availableTokens uint64 258 259 // lastTick is the last clock tick, used to re-calculate the number of tokens 260 // on the bucket. 261 lastTick uint64 262 263 // lock guards the mutable fields. 264 lock sync.Mutex 265 } 266 267 // newBucket creates a new bucket from the given tokens and interval. 268 func newBucket(tokens uint64, interval time.Duration) *bucket { 269 b := &bucket{ 270 startTime: fasttime.Now(), 271 maxTokens: tokens, 272 availableTokens: tokens, 273 interval: interval, 274 } 275 return b 276 } 277 278 // get returns information about the bucket. 279 func (b *bucket) get() (tokens uint64, remaining uint64, retErr error) { 280 b.lock.Lock() 281 defer b.lock.Unlock() 282 283 tokens = b.maxTokens 284 remaining = b.availableTokens 285 return 286 } 287 288 // take attempts to remove a token from the bucket. If there are no tokens 289 // available and the clock has ticked forward, it recalculates the number of 290 // tokens and retries. It returns the limit, remaining tokens, time until 291 // refresh, and whether the take was successful. 292 func (b *bucket) take() (tokens uint64, remaining uint64, reset uint64, ok bool, retErr error) { 293 // Capture the current request time, current tick, and amount of time until 294 // the bucket resets. 295 now := fasttime.Now() 296 297 b.lock.Lock() 298 defer b.lock.Unlock() 299 300 // If the current time is before the start time, it means the server clock was 301 // reset to an earlier time. In that case, rebase to 0. 302 if now < b.startTime { 303 b.startTime = now 304 b.lastTick = 0 305 } 306 307 currTick := tick(b.startTime, now, b.interval) 308 309 tokens = b.maxTokens 310 reset = b.startTime + ((currTick + 1) * uint64(b.interval)) 311 312 // If we're on a new tick since last assessment, perform 313 // a full reset up to maxTokens. 314 if b.lastTick < currTick { 315 b.availableTokens = b.maxTokens 316 b.lastTick = currTick 317 } 318 319 if b.availableTokens > 0 { 320 b.availableTokens-- 321 ok = true 322 remaining = b.availableTokens 323 } 324 325 return 326 } 327 328 // tick is the total number of times the current interval has occurred between 329 // when the time started (start) and the current time (curr). For example, if 330 // the start time was 12:30pm and it's currently 1:00pm, and the interval was 5 331 // minutes, tick would return 6 because 1:00pm is the 6th 5-minute tick. Note 332 // that tick would return 5 at 12:59pm, because it hasn't reached the 6th tick 333 // yet. 334 func tick(start, curr uint64, interval time.Duration) uint64 { 335 return (curr - start) / uint64(interval.Nanoseconds()) 336 }