github.com/lingyao2333/mo-zero@v1.4.1/core/limit/tokenlimit.go (about) 1 package limit 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "strconv" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/lingyao2333/mo-zero/core/logx" 13 "github.com/lingyao2333/mo-zero/core/stores/redis" 14 xrate "golang.org/x/time/rate" 15 ) 16 17 const ( 18 // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key 19 // KEYS[1] as tokens_key 20 // KEYS[2] as timestamp_key 21 script = `local rate = tonumber(ARGV[1]) 22 local capacity = tonumber(ARGV[2]) 23 local now = tonumber(ARGV[3]) 24 local requested = tonumber(ARGV[4]) 25 local fill_time = capacity/rate 26 local ttl = math.floor(fill_time*2) 27 local last_tokens = tonumber(redis.call("get", KEYS[1])) 28 if last_tokens == nil then 29 last_tokens = capacity 30 end 31 32 local last_refreshed = tonumber(redis.call("get", KEYS[2])) 33 if last_refreshed == nil then 34 last_refreshed = 0 35 end 36 37 local delta = math.max(0, now-last_refreshed) 38 local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) 39 local allowed = filled_tokens >= requested 40 local new_tokens = filled_tokens 41 if allowed then 42 new_tokens = filled_tokens - requested 43 end 44 45 redis.call("setex", KEYS[1], ttl, new_tokens) 46 redis.call("setex", KEYS[2], ttl, now) 47 48 return allowed` 49 tokenFormat = "{%s}.tokens" 50 timestampFormat = "{%s}.ts" 51 pingInterval = time.Millisecond * 100 52 ) 53 54 // A TokenLimiter controls how frequently events are allowed to happen with in one second. 55 type TokenLimiter struct { 56 rate int 57 burst int 58 store *redis.Redis 59 tokenKey string 60 timestampKey string 61 rescueLock sync.Mutex 62 redisAlive uint32 63 monitorStarted bool 64 rescueLimiter *xrate.Limiter 65 } 66 67 // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits 68 // bursts of at most burst tokens. 69 func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter { 70 tokenKey := fmt.Sprintf(tokenFormat, key) 71 timestampKey := fmt.Sprintf(timestampFormat, key) 72 73 return &TokenLimiter{ 74 rate: rate, 75 burst: burst, 76 store: store, 77 tokenKey: tokenKey, 78 timestampKey: timestampKey, 79 redisAlive: 1, 80 rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst), 81 } 82 } 83 84 // Allow is shorthand for AllowN(time.Now(), 1). 85 func (lim *TokenLimiter) Allow() bool { 86 return lim.AllowN(time.Now(), 1) 87 } 88 89 // AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context. 90 func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool { 91 return lim.AllowNCtx(ctx, time.Now(), 1) 92 } 93 94 // AllowN reports whether n events may happen at time now. 95 // Use this method if you intend to drop / skip events that exceed the rate. 96 // Otherwise, use Reserve or Wait. 97 func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { 98 return lim.reserveN(context.Background(), now, n) 99 } 100 101 // AllowNCtx reports whether n events may happen at time now with incoming context. 102 // Use this method if you intend to drop / skip events that exceed the rate. 103 // Otherwise, use Reserve or Wait. 104 func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool { 105 return lim.reserveN(ctx, now, n) 106 } 107 108 func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool { 109 if atomic.LoadUint32(&lim.redisAlive) == 0 { 110 return lim.rescueLimiter.AllowN(now, n) 111 } 112 113 resp, err := lim.store.EvalCtx(ctx, 114 script, 115 []string{ 116 lim.tokenKey, 117 lim.timestampKey, 118 }, 119 []string{ 120 strconv.Itoa(lim.rate), 121 strconv.Itoa(lim.burst), 122 strconv.FormatInt(now.Unix(), 10), 123 strconv.Itoa(n), 124 }) 125 // redis allowed == false 126 // Lua boolean false -> r Nil bulk reply 127 if err == redis.Nil { 128 return false 129 } 130 if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { 131 logx.Errorf("fail to use rate limiter: %s", err) 132 return false 133 } 134 if err != nil { 135 logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) 136 lim.startMonitor() 137 return lim.rescueLimiter.AllowN(now, n) 138 } 139 140 code, ok := resp.(int64) 141 if !ok { 142 logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp) 143 lim.startMonitor() 144 return lim.rescueLimiter.AllowN(now, n) 145 } 146 147 // redis allowed == true 148 // Lua boolean true -> r integer reply with value of 1 149 return code == 1 150 } 151 152 func (lim *TokenLimiter) startMonitor() { 153 lim.rescueLock.Lock() 154 defer lim.rescueLock.Unlock() 155 156 if lim.monitorStarted { 157 return 158 } 159 160 lim.monitorStarted = true 161 atomic.StoreUint32(&lim.redisAlive, 0) 162 163 go lim.waitForRedis() 164 } 165 166 func (lim *TokenLimiter) waitForRedis() { 167 ticker := time.NewTicker(pingInterval) 168 defer func() { 169 ticker.Stop() 170 lim.rescueLock.Lock() 171 lim.monitorStarted = false 172 lim.rescueLock.Unlock() 173 }() 174 175 for range ticker.C { 176 if lim.store.Ping() { 177 atomic.StoreUint32(&lim.redisAlive, 1) 178 return 179 } 180 } 181 }