github.com/lingyao2333/mo-zero@v1.4.1/core/limit/periodlimit.go (about)

     1  package limit
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strconv"
     7  	"time"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/stores/redis"
    10  )
    11  
// periodScript is the Lua script evaluated by TakeCtx. It increments the counter
// stored at KEYS[1] and, when the counter has just been created (value 1), sets
// its expiry to ARGV[2]. The counter is then compared against the limit in ARGV[1]:
// the script returns 1 while below the limit, 2 when exactly at the limit and
// 0 once the limit has been exceeded.
//
// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key.
const periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
    redis.call("expire", KEYS[1], window)
end
if current < limit then
    return 1
elseif current == limit then
    return 2
else
    return 0
end`
    26  
    27  const (
    28  	// Unknown means not initialized state.
    29  	Unknown = iota
    30  	// Allowed means allowed state.
    31  	Allowed
    32  	// HitQuota means this request exactly hit the quota.
    33  	HitQuota
    34  	// OverQuota means passed the quota.
    35  	OverQuota
    36  
    37  	internalOverQuota = 0
    38  	internalAllowed   = 1
    39  	internalHitQuota  = 2
    40  )
    41  
// ErrUnknownCode is an error that represents unknown status code.
// TakeCtx returns it when the script reply is not an int64 or is not
// one of the expected status codes.
var ErrUnknownCode = errors.New("unknown status code")
    44  
    45  type (
    46  	// PeriodOption defines the method to customize a PeriodLimit.
    47  	PeriodOption func(l *PeriodLimit)
    48  
    49  	// A PeriodLimit is used to limit requests during a period of time.
    50  	PeriodLimit struct {
    51  		period     int
    52  		quota      int
    53  		limitStore *redis.Redis
    54  		keyPrefix  string
    55  		align      bool
    56  	}
    57  )
    58  
    59  // NewPeriodLimit returns a PeriodLimit with given parameters.
    60  func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
    61  	opts ...PeriodOption) *PeriodLimit {
    62  	limiter := &PeriodLimit{
    63  		period:     period,
    64  		quota:      quota,
    65  		limitStore: limitStore,
    66  		keyPrefix:  keyPrefix,
    67  	}
    68  
    69  	for _, opt := range opts {
    70  		opt(limiter)
    71  	}
    72  
    73  	return limiter
    74  }
    75  
// Take requests a permit for key, it returns the permit state.
// It is a convenience wrapper that calls TakeCtx with context.Background().
func (h *PeriodLimit) Take(key string) (int, error) {
	return h.TakeCtx(context.Background(), key)
}
    80  
    81  // TakeCtx requests a permit with context, it returns the permit state.
    82  func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
    83  	resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
    84  		strconv.Itoa(h.quota),
    85  		strconv.Itoa(h.calcExpireSeconds()),
    86  	})
    87  	if err != nil {
    88  		return Unknown, err
    89  	}
    90  
    91  	code, ok := resp.(int64)
    92  	if !ok {
    93  		return Unknown, ErrUnknownCode
    94  	}
    95  
    96  	switch code {
    97  	case internalOverQuota:
    98  		return OverQuota, nil
    99  	case internalAllowed:
   100  		return Allowed, nil
   101  	case internalHitQuota:
   102  		return HitQuota, nil
   103  	default:
   104  		return Unknown, ErrUnknownCode
   105  	}
   106  }
   107  
   108  func (h *PeriodLimit) calcExpireSeconds() int {
   109  	if h.align {
   110  		now := time.Now()
   111  		_, offset := now.Zone()
   112  		unix := now.Unix() + int64(offset)
   113  		return h.period - int(unix%int64(h.period))
   114  	}
   115  
   116  	return h.period
   117  }
   118  
   119  // Align returns a func to customize a PeriodLimit with alignment.
   120  // For example, if we want to limit end users with 5 sms verification messages every day,
   121  // we need to align with the local timezone and the start of the day.
   122  func Align() PeriodOption {
   123  	return func(l *PeriodLimit) {
   124  		l.align = true
   125  	}
   126  }