github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/store/redis/store.go (about)

     1  package redis
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  	libredis "github.com/redis/go-redis/v9"
    12  
    13  	"github.com/ulule/limiter/v3"
    14  	"github.com/ulule/limiter/v3/drivers/store/common"
    15  )
    16  
const (
	// luaIncrScript adds ARGV[1] (the count) to the counter stored at KEYS[1]
	// and returns {new value, ttl in milliseconds}.
	//
	// When the new value equals the count (the key was missing or held 0), the
	// script sets the key's expiry to ARGV[2] milliseconds (only if positive)
	// and returns ARGV[2] as the ttl. Otherwise it returns the key's current
	// PTTL, which Redis reports as a negative number when the key has no expiry.
	// Redis runs the whole script atomically, so no other client can observe the
	// counter between the INCRBY and the PEXPIRE.
	luaIncrScript = `
local key = KEYS[1]
local count = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local ret = redis.call("incrby", key, ARGV[1])
if ret == count then
	if ttl > 0 then
		redis.call("pexpire", key, ARGV[2])
	end
	return {ret, ttl}
end
ttl = redis.call("pttl", key)
return {ret, ttl}
`
	// luaPeekScript reads the counter at KEYS[1] without modifying it and
	// returns {value, ttl in milliseconds}, or {0, 0} when the key does not exist.
	luaPeekScript = `
local key = KEYS[1]
local v = redis.call("get", key)
if v == false then
	return {0, 0}
end
local ttl = redis.call("pttl", key)
return {tonumber(v), ttl}
`
)
    42  
    43  // Client is an interface thats allows to use a redis cluster or a redis single client seamlessly.
    44  type Client interface {
    45  	Get(ctx context.Context, key string) *libredis.StringCmd
    46  	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.StatusCmd
    47  	Watch(ctx context.Context, handler func(*libredis.Tx) error, keys ...string) error
    48  	Del(ctx context.Context, keys ...string) *libredis.IntCmd
    49  	SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.BoolCmd
    50  	EvalSha(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd
    51  	ScriptLoad(ctx context.Context, script string) *libredis.StringCmd
    52  }
    53  
// Store is the redis store: a limiter.Store whose counters live in Redis under
// "<Prefix>:<key>" and are read and updated through two server-side Lua
// scripts whose SHAs are cached on the struct.
//
// Store contains a sync.RWMutex and therefore must not be copied after first use.
type Store struct {
	// Prefix used for the key.
	Prefix string
	// MaxRetry is the maximum number of retry under race conditions.
	//
	// Deprecated: this option is no longer required since all operations are atomic now.
	MaxRetry int
	// client used to communicate with redis server.
	client Client
	// luaMutex guards luaIncrSHA and luaPeekSHA: loadLuaScripts writes them under
	// the write lock, getLuaIncrSHA/getLuaPeekSHA read them under the read lock.
	luaMutex sync.RWMutex
	// luaLoaded is 1 once both scripts have been loaded successfully and 0
	// otherwise. Every access in this file goes through sync/atomic, which lets
	// preloadLuaScripts skip luaMutex on its fast path (no CAS is involved).
	luaLoaded uint32
	// luaIncrSHA is the SHA of the increment-and-expire script (luaIncrScript).
	luaIncrSHA string
	// luaPeekSHA is the SHA of the read-only peek script (luaPeekScript).
	luaPeekSHA string
}
    72  
    73  // NewStore returns an instance of redis store with defaults.
    74  func NewStore(client Client) (limiter.Store, error) {
    75  	return NewStoreWithOptions(client, limiter.StoreOptions{
    76  		Prefix:          limiter.DefaultPrefix,
    77  		CleanUpInterval: limiter.DefaultCleanUpInterval,
    78  		MaxRetry:        limiter.DefaultMaxRetry,
    79  	})
    80  }
    81  
    82  // NewStoreWithOptions returns an instance of redis store with options.
    83  func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) {
    84  	store := &Store{
    85  		client:   client,
    86  		Prefix:   options.Prefix,
    87  		MaxRetry: options.MaxRetry,
    88  	}
    89  
    90  	err := store.preloadLuaScripts(context.Background())
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return store, nil
    96  }
    97  
    98  // Increment increments the limit by given count & gives back the new limit for given identifier
    99  func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) {
   100  	cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, count, rate.Period.Milliseconds())
   101  	return currentContext(cmd, rate)
   102  }
   103  
   104  // Get returns the limit for given identifier.
   105  func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
   106  	cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, 1, rate.Period.Milliseconds())
   107  	return currentContext(cmd, rate)
   108  }
   109  
   110  // Peek returns the limit for given identifier, without modification on current values.
   111  func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
   112  	cmd := store.evalSHA(ctx, store.getLuaPeekSHA, []string{store.getCacheKey(key)})
   113  	count, ttl, err := parseCountAndTTL(cmd)
   114  	if err != nil {
   115  		return limiter.Context{}, err
   116  	}
   117  
   118  	now := time.Now()
   119  	expiration := now.Add(rate.Period)
   120  	if ttl > 0 {
   121  		expiration = now.Add(time.Duration(ttl) * time.Millisecond)
   122  	}
   123  
   124  	return common.GetContextFromState(now, rate, expiration, count), nil
   125  }
   126  
   127  // Reset returns the limit for given identifier which is set to zero.
   128  func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
   129  	_, err := store.client.Del(ctx, store.getCacheKey(key)).Result()
   130  	if err != nil {
   131  		return limiter.Context{}, err
   132  	}
   133  
   134  	count := int64(0)
   135  	now := time.Now()
   136  	expiration := now.Add(rate.Period)
   137  
   138  	return common.GetContextFromState(now, rate, expiration, count), nil
   139  }
   140  
   141  // getCacheKey returns the full path for an identifier.
   142  func (store *Store) getCacheKey(key string) string {
   143  	buffer := strings.Builder{}
   144  	buffer.WriteString(store.Prefix)
   145  	buffer.WriteString(":")
   146  	buffer.WriteString(key)
   147  	return buffer.String()
   148  }
   149  
   150  // preloadLuaScripts preloads the "incr" and "peek" lua scripts.
   151  func (store *Store) preloadLuaScripts(ctx context.Context) error {
   152  	// Verify if we need to load lua scripts.
   153  	// Inspired by sync.Once.
   154  	if atomic.LoadUint32(&store.luaLoaded) == 0 {
   155  		return store.loadLuaScripts(ctx)
   156  	}
   157  	return nil
   158  }
   159  
   160  // reloadLuaScripts forces a reload of "incr" and "peek" lua scripts.
   161  func (store *Store) reloadLuaScripts(ctx context.Context) error {
   162  	// Reset lua scripts loaded state.
   163  	// Inspired by sync.Once.
   164  	atomic.StoreUint32(&store.luaLoaded, 0)
   165  	return store.loadLuaScripts(ctx)
   166  }
   167  
   168  // loadLuaScripts load "incr" and "peek" lua scripts.
   169  // WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one.
   170  func (store *Store) loadLuaScripts(ctx context.Context) error {
   171  	store.luaMutex.Lock()
   172  	defer store.luaMutex.Unlock()
   173  
   174  	// Check if scripts are already loaded.
   175  	if atomic.LoadUint32(&store.luaLoaded) != 0 {
   176  		return nil
   177  	}
   178  
   179  	luaIncrSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result()
   180  	if err != nil {
   181  		return errors.Wrap(err, `failed to load "incr" lua script`)
   182  	}
   183  
   184  	luaPeekSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result()
   185  	if err != nil {
   186  		return errors.Wrap(err, `failed to load "peek" lua script`)
   187  	}
   188  
   189  	store.luaIncrSHA = luaIncrSHA
   190  	store.luaPeekSHA = luaPeekSHA
   191  
   192  	atomic.StoreUint32(&store.luaLoaded, 1)
   193  
   194  	return nil
   195  }
   196  
   197  // getLuaIncrSHA returns a "thread-safe" value for luaIncrSHA.
   198  func (store *Store) getLuaIncrSHA() string {
   199  	store.luaMutex.RLock()
   200  	defer store.luaMutex.RUnlock()
   201  	return store.luaIncrSHA
   202  }
   203  
   204  // getLuaPeekSHA returns a "thread-safe" value for luaPeekSHA.
   205  func (store *Store) getLuaPeekSHA() string {
   206  	store.luaMutex.RLock()
   207  	defer store.luaMutex.RUnlock()
   208  	return store.luaPeekSHA
   209  }
   210  
   211  // evalSHA eval the redis lua sha and load the scripts if missing.
   212  func (store *Store) evalSHA(ctx context.Context, getSha func() string,
   213  	keys []string, args ...interface{}) *libredis.Cmd {
   214  
   215  	cmd := store.client.EvalSha(ctx, getSha(), keys, args...)
   216  	err := cmd.Err()
   217  	if err == nil || !isLuaScriptGone(err) {
   218  		return cmd
   219  	}
   220  
   221  	err = store.reloadLuaScripts(ctx)
   222  	if err != nil {
   223  		cmd = libredis.NewCmd(ctx)
   224  		cmd.SetErr(err)
   225  		return cmd
   226  	}
   227  
   228  	return store.client.EvalSha(ctx, getSha(), keys, args...)
   229  }
   230  
   231  // isLuaScriptGone returns if the error is a missing lua script from redis server.
   232  func isLuaScriptGone(err error) bool {
   233  	return strings.HasPrefix(err.Error(), "NOSCRIPT")
   234  }
   235  
   236  // parseCountAndTTL parse count and ttl from lua script output.
   237  func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) {
   238  	result, err := cmd.Result()
   239  	if err != nil {
   240  		return 0, 0, errors.Wrap(err, "an error has occurred with redis command")
   241  	}
   242  
   243  	fields, ok := result.([]interface{})
   244  	if !ok || len(fields) != 2 {
   245  		return 0, 0, errors.New("two elements in result were expected")
   246  	}
   247  
   248  	count, ok1 := fields[0].(int64)
   249  	ttl, ok2 := fields[1].(int64)
   250  	if !ok1 || !ok2 {
   251  		return 0, 0, errors.New("type of the count and/or ttl should be number")
   252  	}
   253  
   254  	return count, ttl, nil
   255  }
   256  
   257  func currentContext(cmd *libredis.Cmd, rate limiter.Rate) (limiter.Context, error) {
   258  	count, ttl, err := parseCountAndTTL(cmd)
   259  	if err != nil {
   260  		return limiter.Context{}, err
   261  	}
   262  
   263  	now := time.Now()
   264  	expiration := now.Add(rate.Period)
   265  	if ttl > 0 {
   266  		expiration = now.Add(time.Duration(ttl) * time.Millisecond)
   267  	}
   268  
   269  	return common.GetContextFromState(now, rate, expiration, count), nil
   270  }