github.com/shuguocloud/go-zero@v1.3.0/core/limit/periodlimit.go (about)

     1  package limit
     2  
     3  import (
     4  	"errors"
     5  	"strconv"
     6  	"time"
     7  
     8  	"github.com/shuguocloud/go-zero/core/stores/redis"
     9  )
    10  
    11  // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
    12  const periodScript = `local limit = tonumber(ARGV[1])
    13  local window = tonumber(ARGV[2])
    14  local current = redis.call("INCRBY", KEYS[1], 1)
    15  if current == 1 then
    16      redis.call("expire", KEYS[1], window)
    17      return 1
    18  elseif current < limit then
    19      return 1
    20  elseif current == limit then
    21      return 2
    22  else
    23      return 0
    24  end`
    25  
    26  const (
    27  	// Unknown means not initialized state.
    28  	Unknown = iota
    29  	// Allowed means allowed state.
    30  	Allowed
    31  	// HitQuota means this request exactly hit the quota.
    32  	HitQuota
    33  	// OverQuota means passed the quota.
    34  	OverQuota
    35  
    36  	internalOverQuota = 0
    37  	internalAllowed   = 1
    38  	internalHitQuota  = 2
    39  )
    40  
    41  // ErrUnknownCode is an error that represents unknown status code.
    42  var ErrUnknownCode = errors.New("unknown status code")
    43  
    44  type (
    45  	// PeriodOption defines the method to customize a PeriodLimit.
    46  	PeriodOption func(l *PeriodLimit)
    47  
    48  	// A PeriodLimit is used to limit requests during a period of time.
    49  	PeriodLimit struct {
    50  		period     int
    51  		quota      int
    52  		limitStore *redis.Redis
    53  		keyPrefix  string
    54  		align      bool
    55  	}
    56  )
    57  
    58  // NewPeriodLimit returns a PeriodLimit with given parameters.
    59  func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
    60  	opts ...PeriodOption) *PeriodLimit {
    61  	limiter := &PeriodLimit{
    62  		period:     period,
    63  		quota:      quota,
    64  		limitStore: limitStore,
    65  		keyPrefix:  keyPrefix,
    66  	}
    67  
    68  	for _, opt := range opts {
    69  		opt(limiter)
    70  	}
    71  
    72  	return limiter
    73  }
    74  
    75  // Take requests a permit, it returns the permit state.
    76  func (h *PeriodLimit) Take(key string) (int, error) {
    77  	resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
    78  		strconv.Itoa(h.quota),
    79  		strconv.Itoa(h.calcExpireSeconds()),
    80  	})
    81  	if err != nil {
    82  		return Unknown, err
    83  	}
    84  
    85  	code, ok := resp.(int64)
    86  	if !ok {
    87  		return Unknown, ErrUnknownCode
    88  	}
    89  
    90  	switch code {
    91  	case internalOverQuota:
    92  		return OverQuota, nil
    93  	case internalAllowed:
    94  		return Allowed, nil
    95  	case internalHitQuota:
    96  		return HitQuota, nil
    97  	default:
    98  		return Unknown, ErrUnknownCode
    99  	}
   100  }
   101  
   102  func (h *PeriodLimit) calcExpireSeconds() int {
   103  	if h.align {
   104  		now := time.Now()
   105  		_, offset := now.Zone()
   106  		unix := now.Unix() + int64(offset)
   107  		return h.period - int(unix%int64(h.period))
   108  	}
   109  
   110  	return h.period
   111  }
   112  
   113  // Align returns a func to customize a PeriodLimit with alignment.
   114  // For example, if we want to limit end users with 5 sms verification messages every day,
   115  // we need to align with the local timezone and the start of the day.
   116  func Align() PeriodOption {
   117  	return func(l *PeriodLimit) {
   118  		l.align = true
   119  	}
   120  }