github.com/ulule/limiter/v3@v3.11.3-0.20230613131926-4cb9c1da4633/drivers/store/redis/store.go (about) 1 package redis 2 3 import ( 4 "context" 5 "strings" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/pkg/errors" 11 libredis "github.com/redis/go-redis/v9" 12 13 "github.com/ulule/limiter/v3" 14 "github.com/ulule/limiter/v3/drivers/store/common" 15 ) 16 17 const ( 18 luaIncrScript = ` 19 local key = KEYS[1] 20 local count = tonumber(ARGV[1]) 21 local ttl = tonumber(ARGV[2]) 22 local ret = redis.call("incrby", key, ARGV[1]) 23 if ret == count then 24 if ttl > 0 then 25 redis.call("pexpire", key, ARGV[2]) 26 end 27 return {ret, ttl} 28 end 29 ttl = redis.call("pttl", key) 30 return {ret, ttl} 31 ` 32 luaPeekScript = ` 33 local key = KEYS[1] 34 local v = redis.call("get", key) 35 if v == false then 36 return {0, 0} 37 end 38 local ttl = redis.call("pttl", key) 39 return {tonumber(v), ttl} 40 ` 41 ) 42 43 // Client is an interface thats allows to use a redis cluster or a redis single client seamlessly. 44 type Client interface { 45 Get(ctx context.Context, key string) *libredis.StringCmd 46 Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.StatusCmd 47 Watch(ctx context.Context, handler func(*libredis.Tx) error, keys ...string) error 48 Del(ctx context.Context, keys ...string) *libredis.IntCmd 49 SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.BoolCmd 50 EvalSha(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd 51 ScriptLoad(ctx context.Context, script string) *libredis.StringCmd 52 } 53 54 // Store is the redis store. 55 type Store struct { 56 // Prefix used for the key. 57 Prefix string 58 // MaxRetry is the maximum number of retry under race conditions. 59 // Deprecated: this option is no longer required since all operations are atomic now. 60 MaxRetry int 61 // client used to communicate with redis server. 62 client Client 63 // luaMutex is a mutex used to avoid concurrent access on luaIncrSHA and luaPeekSHA. 64 luaMutex sync.RWMutex 65 // luaLoaded is used for CAS and reduce pressure on luaMutex. 66 luaLoaded uint32 67 // luaIncrSHA is the SHA of increase and expire key script. 68 luaIncrSHA string 69 // luaPeekSHA is the SHA of peek and expire key script. 70 luaPeekSHA string 71 } 72 73 // NewStore returns an instance of redis store with defaults. 74 func NewStore(client Client) (limiter.Store, error) { 75 return NewStoreWithOptions(client, limiter.StoreOptions{ 76 Prefix: limiter.DefaultPrefix, 77 CleanUpInterval: limiter.DefaultCleanUpInterval, 78 MaxRetry: limiter.DefaultMaxRetry, 79 }) 80 } 81 82 // NewStoreWithOptions returns an instance of redis store with options. 83 func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) { 84 store := &Store{ 85 client: client, 86 Prefix: options.Prefix, 87 MaxRetry: options.MaxRetry, 88 } 89 90 err := store.preloadLuaScripts(context.Background()) 91 if err != nil { 92 return nil, err 93 } 94 95 return store, nil 96 } 97 98 // Increment increments the limit by given count & gives back the new limit for given identifier 99 func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) { 100 cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, count, rate.Period.Milliseconds()) 101 return currentContext(cmd, rate) 102 } 103 104 // Get returns the limit for given identifier. 105 func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 106 cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, 1, rate.Period.Milliseconds()) 107 return currentContext(cmd, rate) 108 } 109 110 // Peek returns the limit for given identifier, without modification on current values. 111 func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 112 cmd := store.evalSHA(ctx, store.getLuaPeekSHA, []string{store.getCacheKey(key)}) 113 count, ttl, err := parseCountAndTTL(cmd) 114 if err != nil { 115 return limiter.Context{}, err 116 } 117 118 now := time.Now() 119 expiration := now.Add(rate.Period) 120 if ttl > 0 { 121 expiration = now.Add(time.Duration(ttl) * time.Millisecond) 122 } 123 124 return common.GetContextFromState(now, rate, expiration, count), nil 125 } 126 127 // Reset returns the limit for given identifier which is set to zero. 128 func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 129 _, err := store.client.Del(ctx, store.getCacheKey(key)).Result() 130 if err != nil { 131 return limiter.Context{}, err 132 } 133 134 count := int64(0) 135 now := time.Now() 136 expiration := now.Add(rate.Period) 137 138 return common.GetContextFromState(now, rate, expiration, count), nil 139 } 140 141 // getCacheKey returns the full path for an identifier. 142 func (store *Store) getCacheKey(key string) string { 143 buffer := strings.Builder{} 144 buffer.WriteString(store.Prefix) 145 buffer.WriteString(":") 146 buffer.WriteString(key) 147 return buffer.String() 148 } 149 150 // preloadLuaScripts preloads the "incr" and "peek" lua scripts. 151 func (store *Store) preloadLuaScripts(ctx context.Context) error { 152 // Verify if we need to load lua scripts. 153 // Inspired by sync.Once. 154 if atomic.LoadUint32(&store.luaLoaded) == 0 { 155 return store.loadLuaScripts(ctx) 156 } 157 return nil 158 } 159 160 // reloadLuaScripts forces a reload of "incr" and "peek" lua scripts. 161 func (store *Store) reloadLuaScripts(ctx context.Context) error { 162 // Reset lua scripts loaded state. 163 // Inspired by sync.Once. 164 atomic.StoreUint32(&store.luaLoaded, 0) 165 return store.loadLuaScripts(ctx) 166 } 167 168 // loadLuaScripts load "incr" and "peek" lua scripts. 169 // WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one. 170 func (store *Store) loadLuaScripts(ctx context.Context) error { 171 store.luaMutex.Lock() 172 defer store.luaMutex.Unlock() 173 174 // Check if scripts are already loaded. 175 if atomic.LoadUint32(&store.luaLoaded) != 0 { 176 return nil 177 } 178 179 luaIncrSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result() 180 if err != nil { 181 return errors.Wrap(err, `failed to load "incr" lua script`) 182 } 183 184 luaPeekSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result() 185 if err != nil { 186 return errors.Wrap(err, `failed to load "peek" lua script`) 187 } 188 189 store.luaIncrSHA = luaIncrSHA 190 store.luaPeekSHA = luaPeekSHA 191 192 atomic.StoreUint32(&store.luaLoaded, 1) 193 194 return nil 195 } 196 197 // getLuaIncrSHA returns a "thread-safe" value for luaIncrSHA. 198 func (store *Store) getLuaIncrSHA() string { 199 store.luaMutex.RLock() 200 defer store.luaMutex.RUnlock() 201 return store.luaIncrSHA 202 } 203 204 // getLuaPeekSHA returns a "thread-safe" value for luaPeekSHA. 205 func (store *Store) getLuaPeekSHA() string { 206 store.luaMutex.RLock() 207 defer store.luaMutex.RUnlock() 208 return store.luaPeekSHA 209 } 210 211 // evalSHA eval the redis lua sha and load the scripts if missing. 212 func (store *Store) evalSHA(ctx context.Context, getSha func() string, 213 keys []string, args ...interface{}) *libredis.Cmd { 214 215 cmd := store.client.EvalSha(ctx, getSha(), keys, args...) 216 err := cmd.Err() 217 if err == nil || !isLuaScriptGone(err) { 218 return cmd 219 } 220 221 err = store.reloadLuaScripts(ctx) 222 if err != nil { 223 cmd = libredis.NewCmd(ctx) 224 cmd.SetErr(err) 225 return cmd 226 } 227 228 return store.client.EvalSha(ctx, getSha(), keys, args...) 229 } 230 231 // isLuaScriptGone returns if the error is a missing lua script from redis server. 232 func isLuaScriptGone(err error) bool { 233 return strings.HasPrefix(err.Error(), "NOSCRIPT") 234 } 235 236 // parseCountAndTTL parse count and ttl from lua script output. 237 func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) { 238 result, err := cmd.Result() 239 if err != nil { 240 return 0, 0, errors.Wrap(err, "an error has occurred with redis command") 241 } 242 243 fields, ok := result.([]interface{}) 244 if !ok || len(fields) != 2 { 245 return 0, 0, errors.New("two elements in result were expected") 246 } 247 248 count, ok1 := fields[0].(int64) 249 ttl, ok2 := fields[1].(int64) 250 if !ok1 || !ok2 { 251 return 0, 0, errors.New("type of the count and/or ttl should be number") 252 } 253 254 return count, ttl, nil 255 } 256 257 func currentContext(cmd *libredis.Cmd, rate limiter.Rate) (limiter.Context, error) { 258 count, ttl, err := parseCountAndTTL(cmd) 259 if err != nil { 260 return limiter.Context{}, err 261 } 262 263 now := time.Now() 264 expiration := now.Add(rate.Period) 265 if ttl > 0 { 266 expiration = now.Add(time.Duration(ttl) * time.Millisecond) 267 } 268 269 return common.GetContextFromState(now, rate, expiration, count), nil 270 }