github.com/andy2046/gopie@v0.7.0/pkg/ratelimit/sliding_window.go (about)

     1  package ratelimit
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"os"
     7  	"strconv"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/go-redis/redis"
    12  )
    13  
    14  type (
    15  	// SlidingWindowLimiter implements a limiter with sliding window counter.
    16  	SlidingWindowLimiter struct {
    17  		limit  Limit
    18  		expire int
    19  		store  Store
    20  	}
    21  
    22  	// Store represents a store for limiter state.
    23  	Store interface {
    24  		// Incr add `increment` to field `timestamp` in `key`
    25  		Incr(key string, timestamp int64, increment int) error
    26  		// SetIncr set `key` and add `increment` to field `timestamp` in `key`
    27  		SetIncr(key string, timestamp int64, increment int) error
    28  		// Expire set `key` to expire in `expire` seconds
    29  		Expire(key string, expire int) error
    30  		// Get returns value of field `timestamp` in `key`
    31  		Get(key string, timestamp int64) int
    32  		// Exists check if `key` exists
    33  		Exists(key string) bool
    34  	}
    35  
    36  	redisStore struct {
    37  		clientOpts *redis.Options
    38  		client     *redis.Client
    39  		mu         sync.RWMutex
    40  	}
    41  )
    42  
    43  var logger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile)
    44  
    45  // NotFound will be returned if it fails to get value.
    46  const NotFound = -1
    47  
    48  // NewSlidingWindowLimiter returns a new Limiter at rate `r` tokens per second,
    49  // and the key expires in `expire` seconds.
    50  func NewSlidingWindowLimiter(r Limit, expire int, store Store) *SlidingWindowLimiter {
    51  	return &SlidingWindowLimiter{
    52  		limit:  r,
    53  		store:  store,
    54  		expire: expire,
    55  	}
    56  }
    57  
    58  // Allow is the shortcut for AllowN(time.Now(), key, 1).
    59  func (s *SlidingWindowLimiter) Allow(key string) bool {
    60  	return s.AllowN(time.Now(), key, 1)
    61  }
    62  
    63  // AllowN checks whether `n` requests for `key` may happen at time `now`.
    64  func (s *SlidingWindowLimiter) AllowN(now time.Time, key string, n int) bool {
    65  	sec := timeToSeconds(now)
    66  	var err error
    67  
    68  	if existed := s.store.Exists(key); existed {
    69  		err = s.store.Incr(key, sec, n)
    70  	} else {
    71  		err = s.store.SetIncr(key, sec, n)
    72  		if err == nil {
    73  			s.store.Expire(key, s.expire)
    74  		}
    75  	}
    76  
    77  	if err != nil {
    78  		logger.Println(err)
    79  		return false
    80  	}
    81  
    82  	if count := s.store.Get(key, sec); count == NotFound || float64(count) > float64(s.limit) {
    83  		return false
    84  	}
    85  	return true
    86  }
    87  
    88  // NewRedisStore returns a new Redis Store.
    89  func NewRedisStore(clientOptions *redis.Options) (Store, error) {
    90  	r := &redisStore{
    91  		clientOpts: clientOptions,
    92  	}
    93  	err := r.newConnection()
    94  	return r, err
    95  }
    96  
    97  func (r *redisStore) newConnection() error {
    98  	r.mu.RLock()
    99  	if r.client != nil {
   100  		r.mu.RUnlock()
   101  		return nil
   102  	}
   103  	r.mu.RUnlock()
   104  
   105  	client := redis.NewClient(r.clientOpts)
   106  	_, err := client.Ping().Result()
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	r.mu.Lock()
   112  	r.client = client
   113  	r.mu.Unlock()
   114  	return nil
   115  }
   116  
   117  func (r *redisStore) Incr(key string, field int64, increment int) error {
   118  	r.mu.Lock()
   119  	defer r.mu.Unlock()
   120  	hIncrBy := r.client.HIncrBy(key, strconv.FormatInt(field, 10), int64(increment))
   121  	if err := hIncrBy.Err(); err != nil {
   122  		val := hIncrBy.Val()
   123  		return fmt.Errorf("ratelimit: Incr val=%v error=%v", val, err)
   124  	}
   125  	return nil
   126  }
   127  func (r *redisStore) SetIncr(key string, field int64, increment int) error {
   128  	return r.Incr(key, field, increment)
   129  }
   130  
   131  func (r *redisStore) Expire(key string, timeout int) error {
   132  	r.mu.Lock()
   133  	defer r.mu.Unlock()
   134  	expire := r.client.Expire(key, time.Duration(timeout)*time.Second)
   135  	if err := expire.Err(); err != nil {
   136  		val, ttl := expire.Val(), r.client.TTL(key)
   137  		return fmt.Errorf("ratelimit: Expire val=%v ttl=%v error=%v", val, ttl.Val(), err)
   138  	}
   139  	return nil
   140  }
   141  
   142  func (r *redisStore) Get(key string, field int64) int {
   143  	r.mu.RLock()
   144  	defer r.mu.RUnlock()
   145  	hGet := r.client.HGet(key, strconv.FormatInt(field, 10))
   146  	val := hGet.Val()
   147  	if err := hGet.Err(); err != nil || err == redis.Nil || val == "" {
   148  		logger.Printf("ratelimit: Get val=%v error=%v\n", val, err)
   149  		return NotFound
   150  	}
   151  	n, _ := strconv.Atoi(val)
   152  	return n
   153  }
   154  
   155  func (r *redisStore) Exists(key string) bool {
   156  	r.mu.RLock()
   157  	defer r.mu.RUnlock()
   158  	n, err := r.client.Exists(key).Result()
   159  	if err != nil || n == 0 {
   160  		logger.Printf("ratelimit: Exists record count=%v error=%v\n", n, err)
   161  		return false
   162  	}
   163  	return true
   164  }
   165  
   166  func timeToSeconds(t time.Time) int64 {
   167  	newT := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), 0, time.UTC)
   168  	return newT.Unix()
   169  }