github.com/teng231/glock@v1.1.11/limiter.go (about)

     1  package glock
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/go-redis/redis/v8"
    10  	"github.com/go-redis/redis_rate/v9"
    11  	"github.com/golang-module/carbon/v2"
    12  )
    13  
    14  // Copyright (c) 2017 Pavel Pravosud
    15  // https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
    16  type ILimiter interface {
    17  	// Allow is access comming request and increase counter
    18  	Allow(key string, per string, count int) error
    19  	// Immediate reset counter
    20  	Reset(key string) error
    21  	// Allow using with duration = day
    22  	// AllowInDay(key string, count int) error
    23  	AllowInWeek(key string, count int) error
    24  }
    25  type Limiter struct {
    26  	client   *redis.Client
    27  	timelock time.Duration
    28  	limiter  *redis_rate.Limiter
    29  	tz       string // timezone
    30  }
    31  
    32  const (
    33  	Second     = "second"
    34  	Minute     = "minute"
    35  	Hour       = "hour"
    36  	Day        = "day"
    37  	Week       = "week"
    38  	Restricted = "restricted"
    39  )
    40  
    41  func StartLimiter(cf *ConnectConfig) (*Limiter, error) {
    42  	client := redis.NewClient(&redis.Options{
    43  		Addr:     cf.RedisAddr,
    44  		Password: cf.RedisPw, // no password set
    45  		DB:       cf.RedisDb, // use default DB
    46  	})
    47  	limiter := redis_rate.NewLimiter(client)
    48  	if cf.Timezone == "" {
    49  		cf.Timezone = "Asia/Ho_Chi_Minh"
    50  	}
    51  	return &Limiter{client, cf.Timelock, limiter, cf.Timezone}, nil
    52  }
    53  
    54  // CreateLimiter deprecated
    55  func CreateLimiter(addr, pw string, timelock time.Duration) (*Limiter, error) {
    56  	client := redis.NewClient(&redis.Options{
    57  		Addr:     addr,
    58  		Password: pw, // no password set
    59  		DB:       1,  // use default DB
    60  	})
    61  	limiter := redis_rate.NewLimiter(client)
    62  	return &Limiter{client, timelock, limiter, "Asia/Ho_Chi_Minh"}, nil
    63  }
    64  
    65  func (r *Limiter) Allow(key string, per string, count int) error {
    66  	ctx, cancel := context.WithTimeout(context.Background(), r.timelock)
    67  	defer cancel()
    68  	switch per {
    69  	case Second:
    70  		res, err := r.limiter.Allow(ctx, key, redis_rate.PerSecond(count))
    71  		if err != nil {
    72  			return err
    73  		}
    74  		// log.Print("allowed:", res.Allowed, " remaining:", res.Remaining)
    75  		if res.Allowed == 0 {
    76  			return errors.New(Restricted)
    77  		}
    78  	case Minute:
    79  		res, err := r.limiter.Allow(ctx, key, redis_rate.PerMinute(count))
    80  		if err != nil {
    81  			return err
    82  		}
    83  		// log.Print("allowed:", res.Allowed, " remaining:", res.Remaining)
    84  		if res.Allowed == 0 {
    85  			return errors.New(Restricted)
    86  		}
    87  	case Hour:
    88  		res, err := r.limiter.Allow(ctx, key, redis_rate.PerHour(count))
    89  		if err != nil {
    90  			return err
    91  		}
    92  		// log.Print("allowed:", res.Allowed, " remaining:", res.Remaining)
    93  		if res.Allowed == 0 {
    94  			return errors.New(Restricted)
    95  		}
    96  	case Day:
    97  		// return r.AllowInDay(key, count)
    98  		res, err := r.limiter.Allow(ctx, key, redis_rate.PerHour(24))
    99  		if err != nil {
   100  			return err
   101  		}
   102  		// log.Print("allowed:", res.Allowed, " remaining:", res.Remaining)
   103  		if res.Allowed == 0 {
   104  			return errors.New(Restricted)
   105  		}
   106  	case Week:
   107  		return r.AllowInWeek(key, count)
   108  	}
   109  	return nil
   110  }
   111  
   112  func (r *Limiter) Reset(key string) error {
   113  	ctx, cancel := context.WithTimeout(context.Background(), r.timelock)
   114  	defer cancel()
   115  	// return r.client.Del(ctx, key).Err()
   116  	return r.limiter.Reset(ctx, key)
   117  }
   118  
   119  // func (r *Limiter) AllowInDay(key string, count int) error {
   120  // 	ctx, cancel := context.WithTimeout(context.Background(), r.timelock)
   121  // 	defer cancel()
   122  // 	day := carbon.Now(r.tz).Carbon2Time().Unix() / 86400
   123  // 	key = fmt.Sprintf("%s_%d", key, day)
   124  // 	log.Print(key)
   125  // 	currentValue, err := r.client.Get(ctx, key).Int64()
   126  // 	if err == redis.Nil {
   127  // 		// set time expire 1 day for this key
   128  // 		if err := r.client.SetNX(ctx, key, count, 86400*time.Second).Err(); err != nil {
   129  // 			return err
   130  // 		}
   131  // 		currentValue, _ = r.client.Get(ctx, key).Int64()
   132  // 	}
   133  // 	if currentValue <= 0 {
   134  // 		return errors.New(Restricted)
   135  // 	}
   136  // 	remain, err := r.client.Decr(ctx, key).Result()
   137  // 	if err != nil {
   138  // 		return err
   139  // 	}
   140  // 	if remain < 0 {
   141  // 		return errors.New(Restricted)
   142  // 	}
   143  // 	return nil
   144  // }
   145  
   146  func (r *Limiter) AllowInWeek(key string, count int) error {
   147  	ctx, cancel := context.WithTimeout(context.Background(), r.timelock)
   148  	defer cancel()
   149  	now := carbon.Now(r.tz)
   150  	endOfWeekday := now.SetWeekStartsAt(carbon.Monday).EndOfWeek().EndOfDay()
   151  	hours := now.DiffAbsInHours(endOfWeekday) + 1
   152  	// formula: wts := (ts / 86400 + 3) / 7 | because: 1/1/1970 is thus
   153  	key = fmt.Sprintf("%s_%d", key, (now.Carbon2Time().Unix()/86400+3)/7)
   154  	currentValue, err := r.client.Get(ctx, key).Int64()
   155  	if err == redis.Nil {
   156  		if err := r.client.SetNX(ctx, key, count, time.Duration(hours)*time.Hour).Err(); err != nil {
   157  			return err
   158  		}
   159  		currentValue, _ = r.client.Get(ctx, key).Int64()
   160  	}
   161  	if currentValue <= 0 {
   162  		return errors.New(Restricted)
   163  	}
   164  
   165  	remain, err := r.client.Decr(ctx, key).Result()
   166  	if err != nil {
   167  		return err
   168  	}
   169  	if remain < 0 {
   170  		return errors.New(Restricted)
   171  	}
   172  	return nil
   173  }