github.com/shuguocloud/go-zero@v1.3.0/core/limit/periodlimit.go (about) 1 package limit 2 3 import ( 4 "errors" 5 "strconv" 6 "time" 7 8 "github.com/shuguocloud/go-zero/core/stores/redis" 9 ) 10 11 // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key 12 const periodScript = `local limit = tonumber(ARGV[1]) 13 local window = tonumber(ARGV[2]) 14 local current = redis.call("INCRBY", KEYS[1], 1) 15 if current == 1 then 16 redis.call("expire", KEYS[1], window) 17 return 1 18 elseif current < limit then 19 return 1 20 elseif current == limit then 21 return 2 22 else 23 return 0 24 end` 25 26 const ( 27 // Unknown means not initialized state. 28 Unknown = iota 29 // Allowed means allowed state. 30 Allowed 31 // HitQuota means this request exactly hit the quota. 32 HitQuota 33 // OverQuota means passed the quota. 34 OverQuota 35 36 internalOverQuota = 0 37 internalAllowed = 1 38 internalHitQuota = 2 39 ) 40 41 // ErrUnknownCode is an error that represents unknown status code. 42 var ErrUnknownCode = errors.New("unknown status code") 43 44 type ( 45 // PeriodOption defines the method to customize a PeriodLimit. 46 PeriodOption func(l *PeriodLimit) 47 48 // A PeriodLimit is used to limit requests during a period of time. 49 PeriodLimit struct { 50 period int 51 quota int 52 limitStore *redis.Redis 53 keyPrefix string 54 align bool 55 } 56 ) 57 58 // NewPeriodLimit returns a PeriodLimit with given parameters. 59 func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, 60 opts ...PeriodOption) *PeriodLimit { 61 limiter := &PeriodLimit{ 62 period: period, 63 quota: quota, 64 limitStore: limitStore, 65 keyPrefix: keyPrefix, 66 } 67 68 for _, opt := range opts { 69 opt(limiter) 70 } 71 72 return limiter 73 } 74 75 // Take requests a permit, it returns the permit state. 76 func (h *PeriodLimit) Take(key string) (int, error) { 77 resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{ 78 strconv.Itoa(h.quota), 79 strconv.Itoa(h.calcExpireSeconds()), 80 }) 81 if err != nil { 82 return Unknown, err 83 } 84 85 code, ok := resp.(int64) 86 if !ok { 87 return Unknown, ErrUnknownCode 88 } 89 90 switch code { 91 case internalOverQuota: 92 return OverQuota, nil 93 case internalAllowed: 94 return Allowed, nil 95 case internalHitQuota: 96 return HitQuota, nil 97 default: 98 return Unknown, ErrUnknownCode 99 } 100 } 101 102 func (h *PeriodLimit) calcExpireSeconds() int { 103 if h.align { 104 now := time.Now() 105 _, offset := now.Zone() 106 unix := now.Unix() + int64(offset) 107 return h.period - int(unix%int64(h.period)) 108 } 109 110 return h.period 111 } 112 113 // Align returns a func to customize a PeriodLimit with alignment. 114 // For example, if we want to limit end users with 5 sms verification messages every day, 115 // we need to align with the local timezone and the start of the day. 116 func Align() PeriodOption { 117 return func(l *PeriodLimit) { 118 l.align = true 119 } 120 }