github.com/lingyao2333/mo-zero@v1.4.1/core/limit/periodlimit.go (about) 1 package limit 2 3 import ( 4 "context" 5 "errors" 6 "strconv" 7 "time" 8 9 "github.com/lingyao2333/mo-zero/core/stores/redis" 10 ) 11 12 // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key 13 const periodScript = `local limit = tonumber(ARGV[1]) 14 local window = tonumber(ARGV[2]) 15 local current = redis.call("INCRBY", KEYS[1], 1) 16 if current == 1 then 17 redis.call("expire", KEYS[1], window) 18 end 19 if current < limit then 20 return 1 21 elseif current == limit then 22 return 2 23 else 24 return 0 25 end` 26 27 const ( 28 // Unknown means not initialized state. 29 Unknown = iota 30 // Allowed means allowed state. 31 Allowed 32 // HitQuota means this request exactly hit the quota. 33 HitQuota 34 // OverQuota means passed the quota. 35 OverQuota 36 37 internalOverQuota = 0 38 internalAllowed = 1 39 internalHitQuota = 2 40 ) 41 42 // ErrUnknownCode is an error that represents unknown status code. 43 var ErrUnknownCode = errors.New("unknown status code") 44 45 type ( 46 // PeriodOption defines the method to customize a PeriodLimit. 47 PeriodOption func(l *PeriodLimit) 48 49 // A PeriodLimit is used to limit requests during a period of time. 50 PeriodLimit struct { 51 period int 52 quota int 53 limitStore *redis.Redis 54 keyPrefix string 55 align bool 56 } 57 ) 58 59 // NewPeriodLimit returns a PeriodLimit with given parameters. 60 func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, 61 opts ...PeriodOption) *PeriodLimit { 62 limiter := &PeriodLimit{ 63 period: period, 64 quota: quota, 65 limitStore: limitStore, 66 keyPrefix: keyPrefix, 67 } 68 69 for _, opt := range opts { 70 opt(limiter) 71 } 72 73 return limiter 74 } 75 76 // Take requests a permit, it returns the permit state. 77 func (h *PeriodLimit) Take(key string) (int, error) { 78 return h.TakeCtx(context.Background(), key) 79 } 80 81 // TakeCtx requests a permit with context, it returns the permit state. 82 func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) { 83 resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{ 84 strconv.Itoa(h.quota), 85 strconv.Itoa(h.calcExpireSeconds()), 86 }) 87 if err != nil { 88 return Unknown, err 89 } 90 91 code, ok := resp.(int64) 92 if !ok { 93 return Unknown, ErrUnknownCode 94 } 95 96 switch code { 97 case internalOverQuota: 98 return OverQuota, nil 99 case internalAllowed: 100 return Allowed, nil 101 case internalHitQuota: 102 return HitQuota, nil 103 default: 104 return Unknown, ErrUnknownCode 105 } 106 } 107 108 func (h *PeriodLimit) calcExpireSeconds() int { 109 if h.align { 110 now := time.Now() 111 _, offset := now.Zone() 112 unix := now.Unix() + int64(offset) 113 return h.period - int(unix%int64(h.period)) 114 } 115 116 return h.period 117 } 118 119 // Align returns a func to customize a PeriodLimit with alignment. 120 // For example, if we want to limit end users with 5 sms verification messages every day, 121 // we need to align with the local timezone and the start of the day. 122 func Align() PeriodOption { 123 return func(l *PeriodLimit) { 124 l.align = true 125 } 126 }