github.com/shuguocloud/go-zero@v1.3.0/core/limit/tokenlimit.go (about) 1 package limit 2 3 import ( 4 "fmt" 5 "strconv" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/shuguocloud/go-zero/core/logx" 11 "github.com/shuguocloud/go-zero/core/stores/redis" 12 xrate "golang.org/x/time/rate" 13 ) 14 15 const ( 16 // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key 17 // KEYS[1] as tokens_key 18 // KEYS[2] as timestamp_key 19 script = `local rate = tonumber(ARGV[1]) 20 local capacity = tonumber(ARGV[2]) 21 local now = tonumber(ARGV[3]) 22 local requested = tonumber(ARGV[4]) 23 local fill_time = capacity/rate 24 local ttl = math.floor(fill_time*2) 25 local last_tokens = tonumber(redis.call("get", KEYS[1])) 26 if last_tokens == nil then 27 last_tokens = capacity 28 end 29 30 local last_refreshed = tonumber(redis.call("get", KEYS[2])) 31 if last_refreshed == nil then 32 last_refreshed = 0 33 end 34 35 local delta = math.max(0, now-last_refreshed) 36 local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) 37 local allowed = filled_tokens >= requested 38 local new_tokens = filled_tokens 39 if allowed then 40 new_tokens = filled_tokens - requested 41 end 42 43 redis.call("setex", KEYS[1], ttl, new_tokens) 44 redis.call("setex", KEYS[2], ttl, now) 45 46 return allowed` 47 tokenFormat = "{%s}.tokens" 48 timestampFormat = "{%s}.ts" 49 pingInterval = time.Millisecond * 100 50 ) 51 52 // A TokenLimiter controls how frequently events are allowed to happen with in one second. 53 type TokenLimiter struct { 54 rate int 55 burst int 56 store *redis.Redis 57 tokenKey string 58 timestampKey string 59 rescueLock sync.Mutex 60 redisAlive uint32 61 rescueLimiter *xrate.Limiter 62 monitorStarted bool 63 } 64 65 // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits 66 // bursts of at most burst tokens. 67 func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter { 68 tokenKey := fmt.Sprintf(tokenFormat, key) 69 timestampKey := fmt.Sprintf(timestampFormat, key) 70 71 return &TokenLimiter{ 72 rate: rate, 73 burst: burst, 74 store: store, 75 tokenKey: tokenKey, 76 timestampKey: timestampKey, 77 redisAlive: 1, 78 rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst), 79 } 80 } 81 82 // Allow is shorthand for AllowN(time.Now(), 1). 83 func (lim *TokenLimiter) Allow() bool { 84 return lim.AllowN(time.Now(), 1) 85 } 86 87 // AllowN reports whether n events may happen at time now. 88 // Use this method if you intend to drop / skip events that exceed the rate. 89 // Otherwise, use Reserve or Wait. 90 func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { 91 return lim.reserveN(now, n) 92 } 93 94 func (lim *TokenLimiter) reserveN(now time.Time, n int) bool { 95 if atomic.LoadUint32(&lim.redisAlive) == 0 { 96 return lim.rescueLimiter.AllowN(now, n) 97 } 98 99 resp, err := lim.store.Eval( 100 script, 101 []string{ 102 lim.tokenKey, 103 lim.timestampKey, 104 }, 105 []string{ 106 strconv.Itoa(lim.rate), 107 strconv.Itoa(lim.burst), 108 strconv.FormatInt(now.Unix(), 10), 109 strconv.Itoa(n), 110 }) 111 // redis allowed == false 112 // Lua boolean false -> r Nil bulk reply 113 if err == redis.Nil { 114 return false 115 } 116 if err != nil { 117 logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) 118 lim.startMonitor() 119 return lim.rescueLimiter.AllowN(now, n) 120 } 121 122 code, ok := resp.(int64) 123 if !ok { 124 logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp) 125 lim.startMonitor() 126 return lim.rescueLimiter.AllowN(now, n) 127 } 128 129 // redis allowed == true 130 // Lua boolean true -> r integer reply with value of 1 131 return code == 1 132 } 133 134 func (lim *TokenLimiter) startMonitor() { 135 lim.rescueLock.Lock() 136 defer lim.rescueLock.Unlock() 137 138 if lim.monitorStarted { 139 return 140 } 141 142 lim.monitorStarted = true 143 atomic.StoreUint32(&lim.redisAlive, 0) 144 145 go lim.waitForRedis() 146 } 147 148 func (lim *TokenLimiter) waitForRedis() { 149 ticker := time.NewTicker(pingInterval) 150 defer func() { 151 ticker.Stop() 152 lim.rescueLock.Lock() 153 lim.monitorStarted = false 154 lim.rescueLock.Unlock() 155 }() 156 157 for range ticker.C { 158 if lim.store.Ping() { 159 atomic.StoreUint32(&lim.redisAlive, 1) 160 return 161 } 162 } 163 }