github.com/lingyao2333/mo-zero@v1.4.1/core/limit/tokenlimit.go (about)

     1  package limit
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/lingyao2333/mo-zero/core/logx"
    13  	"github.com/lingyao2333/mo-zero/core/stores/redis"
    14  	xrate "golang.org/x/time/rate"
    15  )
    16  
    17  const (
    18  	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
    19  	// KEYS[1] as tokens_key
    20  	// KEYS[2] as timestamp_key
    21  	script = `local rate = tonumber(ARGV[1])
    22  local capacity = tonumber(ARGV[2])
    23  local now = tonumber(ARGV[3])
    24  local requested = tonumber(ARGV[4])
    25  local fill_time = capacity/rate
    26  local ttl = math.floor(fill_time*2)
    27  local last_tokens = tonumber(redis.call("get", KEYS[1]))
    28  if last_tokens == nil then
    29      last_tokens = capacity
    30  end
    31  
    32  local last_refreshed = tonumber(redis.call("get", KEYS[2]))
    33  if last_refreshed == nil then
    34      last_refreshed = 0
    35  end
    36  
    37  local delta = math.max(0, now-last_refreshed)
    38  local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
    39  local allowed = filled_tokens >= requested
    40  local new_tokens = filled_tokens
    41  if allowed then
    42      new_tokens = filled_tokens - requested
    43  end
    44  
    45  redis.call("setex", KEYS[1], ttl, new_tokens)
    46  redis.call("setex", KEYS[2], ttl, now)
    47  
    48  return allowed`
    49  	tokenFormat     = "{%s}.tokens"
    50  	timestampFormat = "{%s}.ts"
    51  	pingInterval    = time.Millisecond * 100
    52  )
    53  
    54  // A TokenLimiter controls how frequently events are allowed to happen with in one second.
    55  type TokenLimiter struct {
    56  	rate           int
    57  	burst          int
    58  	store          *redis.Redis
    59  	tokenKey       string
    60  	timestampKey   string
    61  	rescueLock     sync.Mutex
    62  	redisAlive     uint32
    63  	monitorStarted bool
    64  	rescueLimiter  *xrate.Limiter
    65  }
    66  
    67  // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
    68  // bursts of at most burst tokens.
    69  func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
    70  	tokenKey := fmt.Sprintf(tokenFormat, key)
    71  	timestampKey := fmt.Sprintf(timestampFormat, key)
    72  
    73  	return &TokenLimiter{
    74  		rate:          rate,
    75  		burst:         burst,
    76  		store:         store,
    77  		tokenKey:      tokenKey,
    78  		timestampKey:  timestampKey,
    79  		redisAlive:    1,
    80  		rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
    81  	}
    82  }
    83  
    84  // Allow is shorthand for AllowN(time.Now(), 1).
    85  func (lim *TokenLimiter) Allow() bool {
    86  	return lim.AllowN(time.Now(), 1)
    87  }
    88  
    89  // AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context.
    90  func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
    91  	return lim.AllowNCtx(ctx, time.Now(), 1)
    92  }
    93  
    94  // AllowN reports whether n events may happen at time now.
    95  // Use this method if you intend to drop / skip events that exceed the rate.
    96  // Otherwise, use Reserve or Wait.
    97  func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
    98  	return lim.reserveN(context.Background(), now, n)
    99  }
   100  
   101  // AllowNCtx reports whether n events may happen at time now with incoming context.
   102  // Use this method if you intend to drop / skip events that exceed the rate.
   103  // Otherwise, use Reserve or Wait.
   104  func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
   105  	return lim.reserveN(ctx, now, n)
   106  }
   107  
   108  func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
   109  	if atomic.LoadUint32(&lim.redisAlive) == 0 {
   110  		return lim.rescueLimiter.AllowN(now, n)
   111  	}
   112  
   113  	resp, err := lim.store.EvalCtx(ctx,
   114  		script,
   115  		[]string{
   116  			lim.tokenKey,
   117  			lim.timestampKey,
   118  		},
   119  		[]string{
   120  			strconv.Itoa(lim.rate),
   121  			strconv.Itoa(lim.burst),
   122  			strconv.FormatInt(now.Unix(), 10),
   123  			strconv.Itoa(n),
   124  		})
   125  	// redis allowed == false
   126  	// Lua boolean false -> r Nil bulk reply
   127  	if err == redis.Nil {
   128  		return false
   129  	}
   130  	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
   131  		logx.Errorf("fail to use rate limiter: %s", err)
   132  		return false
   133  	}
   134  	if err != nil {
   135  		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
   136  		lim.startMonitor()
   137  		return lim.rescueLimiter.AllowN(now, n)
   138  	}
   139  
   140  	code, ok := resp.(int64)
   141  	if !ok {
   142  		logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
   143  		lim.startMonitor()
   144  		return lim.rescueLimiter.AllowN(now, n)
   145  	}
   146  
   147  	// redis allowed == true
   148  	// Lua boolean true -> r integer reply with value of 1
   149  	return code == 1
   150  }
   151  
   152  func (lim *TokenLimiter) startMonitor() {
   153  	lim.rescueLock.Lock()
   154  	defer lim.rescueLock.Unlock()
   155  
   156  	if lim.monitorStarted {
   157  		return
   158  	}
   159  
   160  	lim.monitorStarted = true
   161  	atomic.StoreUint32(&lim.redisAlive, 0)
   162  
   163  	go lim.waitForRedis()
   164  }
   165  
   166  func (lim *TokenLimiter) waitForRedis() {
   167  	ticker := time.NewTicker(pingInterval)
   168  	defer func() {
   169  		ticker.Stop()
   170  		lim.rescueLock.Lock()
   171  		lim.monitorStarted = false
   172  		lim.rescueLock.Unlock()
   173  	}()
   174  
   175  	for range ticker.C {
   176  		if lim.store.Ping() {
   177  			atomic.StoreUint32(&lim.redisAlive, 1)
   178  			return
   179  		}
   180  	}
   181  }