github.com/shuguocloud/go-zero@v1.3.0/core/limit/tokenlimit.go (about)

     1  package limit
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/shuguocloud/go-zero/core/logx"
    11  	"github.com/shuguocloud/go-zero/core/stores/redis"
    12  	xrate "golang.org/x/time/rate"
    13  )
    14  
    15  const (
    16  	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
    17  	// KEYS[1] as tokens_key
    18  	// KEYS[2] as timestamp_key
    19  	script = `local rate = tonumber(ARGV[1])
    20  local capacity = tonumber(ARGV[2])
    21  local now = tonumber(ARGV[3])
    22  local requested = tonumber(ARGV[4])
    23  local fill_time = capacity/rate
    24  local ttl = math.floor(fill_time*2)
    25  local last_tokens = tonumber(redis.call("get", KEYS[1]))
    26  if last_tokens == nil then
    27      last_tokens = capacity
    28  end
    29  
    30  local last_refreshed = tonumber(redis.call("get", KEYS[2]))
    31  if last_refreshed == nil then
    32      last_refreshed = 0
    33  end
    34  
    35  local delta = math.max(0, now-last_refreshed)
    36  local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
    37  local allowed = filled_tokens >= requested
    38  local new_tokens = filled_tokens
    39  if allowed then
    40      new_tokens = filled_tokens - requested
    41  end
    42  
    43  redis.call("setex", KEYS[1], ttl, new_tokens)
    44  redis.call("setex", KEYS[2], ttl, now)
    45  
    46  return allowed`
    47  	tokenFormat     = "{%s}.tokens"
    48  	timestampFormat = "{%s}.ts"
    49  	pingInterval    = time.Millisecond * 100
    50  )
    51  
    52  // A TokenLimiter controls how frequently events are allowed to happen with in one second.
    53  type TokenLimiter struct {
    54  	rate           int
    55  	burst          int
    56  	store          *redis.Redis
    57  	tokenKey       string
    58  	timestampKey   string
    59  	rescueLock     sync.Mutex
    60  	redisAlive     uint32
    61  	rescueLimiter  *xrate.Limiter
    62  	monitorStarted bool
    63  }
    64  
    65  // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
    66  // bursts of at most burst tokens.
    67  func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
    68  	tokenKey := fmt.Sprintf(tokenFormat, key)
    69  	timestampKey := fmt.Sprintf(timestampFormat, key)
    70  
    71  	return &TokenLimiter{
    72  		rate:          rate,
    73  		burst:         burst,
    74  		store:         store,
    75  		tokenKey:      tokenKey,
    76  		timestampKey:  timestampKey,
    77  		redisAlive:    1,
    78  		rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
    79  	}
    80  }
    81  
    82  // Allow is shorthand for AllowN(time.Now(), 1).
    83  func (lim *TokenLimiter) Allow() bool {
    84  	return lim.AllowN(time.Now(), 1)
    85  }
    86  
    87  // AllowN reports whether n events may happen at time now.
    88  // Use this method if you intend to drop / skip events that exceed the rate.
    89  // Otherwise, use Reserve or Wait.
    90  func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
    91  	return lim.reserveN(now, n)
    92  }
    93  
    94  func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
    95  	if atomic.LoadUint32(&lim.redisAlive) == 0 {
    96  		return lim.rescueLimiter.AllowN(now, n)
    97  	}
    98  
    99  	resp, err := lim.store.Eval(
   100  		script,
   101  		[]string{
   102  			lim.tokenKey,
   103  			lim.timestampKey,
   104  		},
   105  		[]string{
   106  			strconv.Itoa(lim.rate),
   107  			strconv.Itoa(lim.burst),
   108  			strconv.FormatInt(now.Unix(), 10),
   109  			strconv.Itoa(n),
   110  		})
   111  	// redis allowed == false
   112  	// Lua boolean false -> r Nil bulk reply
   113  	if err == redis.Nil {
   114  		return false
   115  	}
   116  	if err != nil {
   117  		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
   118  		lim.startMonitor()
   119  		return lim.rescueLimiter.AllowN(now, n)
   120  	}
   121  
   122  	code, ok := resp.(int64)
   123  	if !ok {
   124  		logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
   125  		lim.startMonitor()
   126  		return lim.rescueLimiter.AllowN(now, n)
   127  	}
   128  
   129  	// redis allowed == true
   130  	// Lua boolean true -> r integer reply with value of 1
   131  	return code == 1
   132  }
   133  
   134  func (lim *TokenLimiter) startMonitor() {
   135  	lim.rescueLock.Lock()
   136  	defer lim.rescueLock.Unlock()
   137  
   138  	if lim.monitorStarted {
   139  		return
   140  	}
   141  
   142  	lim.monitorStarted = true
   143  	atomic.StoreUint32(&lim.redisAlive, 0)
   144  
   145  	go lim.waitForRedis()
   146  }
   147  
   148  func (lim *TokenLimiter) waitForRedis() {
   149  	ticker := time.NewTicker(pingInterval)
   150  	defer func() {
   151  		ticker.Stop()
   152  		lim.rescueLock.Lock()
   153  		lim.monitorStarted = false
   154  		lim.rescueLock.Unlock()
   155  	}()
   156  
   157  	for range ticker.C {
   158  		if lim.store.Ping() {
   159  			atomic.StoreUint32(&lim.redisAlive, 1)
   160  			return
   161  		}
   162  	}
   163  }