github.com/andy2046/gopie@v0.7.0/pkg/ratelimit/sliding_window.go (about) 1 package ratelimit 2 3 import ( 4 "fmt" 5 "log" 6 "os" 7 "strconv" 8 "sync" 9 "time" 10 11 "github.com/go-redis/redis" 12 ) 13 14 type ( 15 // SlidingWindowLimiter implements a limiter with sliding window counter. 16 SlidingWindowLimiter struct { 17 limit Limit 18 expire int 19 store Store 20 } 21 22 // Store represents a store for limiter state. 23 Store interface { 24 // Incr add `increment` to field `timestamp` in `key` 25 Incr(key string, timestamp int64, increment int) error 26 // SetIncr set `key` and add `increment` to field `timestamp` in `key` 27 SetIncr(key string, timestamp int64, increment int) error 28 // Expire set `key` to expire in `expire` seconds 29 Expire(key string, expire int) error 30 // Get returns value of field `timestamp` in `key` 31 Get(key string, timestamp int64) int 32 // Exists check if `key` exists 33 Exists(key string) bool 34 } 35 36 redisStore struct { 37 clientOpts *redis.Options 38 client *redis.Client 39 mu sync.RWMutex 40 } 41 ) 42 43 var logger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) 44 45 // NotFound will be returned if it fails to get value. 46 const NotFound = -1 47 48 // NewSlidingWindowLimiter returns a new Limiter at rate `r` tokens per second, 49 // and the key expires in `expire` seconds. 50 func NewSlidingWindowLimiter(r Limit, expire int, store Store) *SlidingWindowLimiter { 51 return &SlidingWindowLimiter{ 52 limit: r, 53 store: store, 54 expire: expire, 55 } 56 } 57 58 // Allow is the shortcut for AllowN(time.Now(), key, 1). 59 func (s *SlidingWindowLimiter) Allow(key string) bool { 60 return s.AllowN(time.Now(), key, 1) 61 } 62 63 // AllowN checks whether `n` requests for `key` may happen at time `now`. 64 func (s *SlidingWindowLimiter) AllowN(now time.Time, key string, n int) bool { 65 sec := timeToSeconds(now) 66 var err error 67 68 if existed := s.store.Exists(key); existed { 69 err = s.store.Incr(key, sec, n) 70 } else { 71 err = s.store.SetIncr(key, sec, n) 72 if err == nil { 73 s.store.Expire(key, s.expire) 74 } 75 } 76 77 if err != nil { 78 logger.Println(err) 79 return false 80 } 81 82 if count := s.store.Get(key, sec); count == NotFound || float64(count) > float64(s.limit) { 83 return false 84 } 85 return true 86 } 87 88 // NewRedisStore returns a new Redis Store. 89 func NewRedisStore(clientOptions *redis.Options) (Store, error) { 90 r := &redisStore{ 91 clientOpts: clientOptions, 92 } 93 err := r.newConnection() 94 return r, err 95 } 96 97 func (r *redisStore) newConnection() error { 98 r.mu.RLock() 99 if r.client != nil { 100 r.mu.RUnlock() 101 return nil 102 } 103 r.mu.RUnlock() 104 105 client := redis.NewClient(r.clientOpts) 106 _, err := client.Ping().Result() 107 if err != nil { 108 return err 109 } 110 111 r.mu.Lock() 112 r.client = client 113 r.mu.Unlock() 114 return nil 115 } 116 117 func (r *redisStore) Incr(key string, field int64, increment int) error { 118 r.mu.Lock() 119 defer r.mu.Unlock() 120 hIncrBy := r.client.HIncrBy(key, strconv.FormatInt(field, 10), int64(increment)) 121 if err := hIncrBy.Err(); err != nil { 122 val := hIncrBy.Val() 123 return fmt.Errorf("ratelimit: Incr val=%v error=%v", val, err) 124 } 125 return nil 126 } 127 func (r *redisStore) SetIncr(key string, field int64, increment int) error { 128 return r.Incr(key, field, increment) 129 } 130 131 func (r *redisStore) Expire(key string, timeout int) error { 132 r.mu.Lock() 133 defer r.mu.Unlock() 134 expire := r.client.Expire(key, time.Duration(timeout)*time.Second) 135 if err := expire.Err(); err != nil { 136 val, ttl := expire.Val(), r.client.TTL(key) 137 return fmt.Errorf("ratelimit: Expire val=%v ttl=%v error=%v", val, ttl.Val(), err) 138 } 139 return nil 140 } 141 142 func (r *redisStore) Get(key string, field int64) int { 143 r.mu.RLock() 144 defer r.mu.RUnlock() 145 hGet := r.client.HGet(key, strconv.FormatInt(field, 10)) 146 val := hGet.Val() 147 if err := hGet.Err(); err != nil || err == redis.Nil || val == "" { 148 logger.Printf("ratelimit: Get val=%v error=%v\n", val, err) 149 return NotFound 150 } 151 n, _ := strconv.Atoi(val) 152 return n 153 } 154 155 func (r *redisStore) Exists(key string) bool { 156 r.mu.RLock() 157 defer r.mu.RUnlock() 158 n, err := r.client.Exists(key).Result() 159 if err != nil || n == 0 { 160 logger.Printf("ratelimit: Exists record count=%v error=%v\n", n, err) 161 return false 162 } 163 return true 164 } 165 166 func timeToSeconds(t time.Time) int64 { 167 newT := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), 0, time.UTC) 168 return newT.Unix() 169 }