github.com/dtm-labs/rockscache@v0.1.1/batch.go (about) 1 package rockscache 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math/rand" 8 "runtime/debug" 9 "sync" 10 "time" 11 12 "github.com/lithammer/shortuuid" 13 "github.com/redis/go-redis/v9" 14 ) 15 16 var ( 17 errNeedFetch = errors.New("need fetch") 18 errNeedAsyncFetch = errors.New("need async fetch") 19 ) 20 21 func (c *Client) luaGetBatch(ctx context.Context, keys []string, owner string) ([]interface{}, error) { 22 res, err := callLua(ctx, c.rdb, `-- luaGetBatch 23 local rets = {} 24 for i, key in ipairs(KEYS) 25 do 26 local v = redis.call('HGET', key, 'value') 27 local lu = redis.call('HGET', key, 'lockUntil') 28 if lu ~= false and tonumber(lu) < tonumber(ARGV[1]) or lu == false and v == false then 29 redis.call('HSET', key, 'lockUntil', ARGV[2]) 30 redis.call('HSET', key, 'lockOwner', ARGV[3]) 31 table.insert(rets, { v, 'LOCKED' }) 32 else 33 table.insert(rets, {v, lu}) 34 end 35 end 36 return rets 37 `, keys, []interface{}{now(), now() + int64(c.Options.LockExpire/time.Second), owner}) 38 debugf("luaGetBatch return: %v, %v", res, err) 39 if err != nil { 40 return nil, err 41 } 42 return res.([]interface{}), nil 43 } 44 45 func (c *Client) luaSetBatch(ctx context.Context, keys []string, values []string, expires []int, owner string) error { 46 var vals = make([]interface{}, 0, 2+len(values)) 47 vals = append(vals, owner) 48 for _, v := range values { 49 vals = append(vals, v) 50 } 51 for _, ex := range expires { 52 vals = append(vals, ex) 53 } 54 _, err := callLua(ctx, c.rdb, `-- luaSetBatch 55 local n = #KEYS 56 for i, key in ipairs(KEYS) 57 do 58 local o = redis.call('HGET', key, 'lockOwner') 59 if o ~= ARGV[1] then 60 return 61 end 62 redis.call('HSET', key, 'value', ARGV[i+1]) 63 redis.call('HDEL', key, 'lockUntil') 64 redis.call('HDEL', key, 'lockOwner') 65 redis.call('EXPIRE', key, ARGV[i+1+n]) 66 end 67 `, keys, vals) 68 return err 69 } 70 71 func (c *Client) fetchBatch(ctx context.Context, keys []string, idxs []int, expire time.Duration, owner string, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) { 72 defer func() { 73 if r := recover(); r != nil { 74 debug.PrintStack() 75 } 76 }() 77 data, err := fn(idxs) 78 if err != nil { 79 for _, idx := range idxs { 80 _ = c.UnlockForUpdate(ctx, keys[idx], owner) 81 } 82 return nil, err 83 } 84 85 if data == nil { 86 // incase data is nil 87 data = make(map[int]string) 88 } 89 90 var batchKeys []string 91 var batchValues []string 92 var batchExpires []int 93 94 for _, idx := range idxs { 95 v := data[idx] 96 ex := expire - c.Options.Delay - time.Duration(rand.Float64()*c.Options.RandomExpireAdjustment*float64(expire)) 97 if v == "" { 98 if c.Options.EmptyExpire == 0 { // if empty expire is 0, then delete the key 99 _ = c.rdb.Del(ctx, keys[idx]).Err() 100 if err != nil { 101 debugf("batch: del failed key=%s err:%s", keys[idx], err.Error()) 102 } 103 continue 104 } 105 ex = c.Options.EmptyExpire 106 107 data[idx] = v // incase idx not in data 108 } 109 batchKeys = append(batchKeys, keys[idx]) 110 batchValues = append(batchValues, v) 111 batchExpires = append(batchExpires, int(ex/time.Second)) 112 } 113 114 err = c.luaSetBatch(ctx, batchKeys, batchValues, batchExpires, owner) 115 if err != nil { 116 debugf("batch: luaSetBatch failed keys=%s err:%s", keys, err.Error()) 117 } 118 return data, nil 119 } 120 121 func (c *Client) keysIdx(keys []string) (idxs []int) { 122 for i := range keys { 123 idxs = append(idxs, i) 124 } 125 return idxs 126 } 127 128 type pair struct { 129 idx int 130 data string 131 err error 132 } 133 134 func (c *Client) weakFetchBatch(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) { 135 debugf("batch: weakFetch keys=%+v", keys) 136 var result = make(map[int]string) 137 owner := shortuuid.New() 138 var toGet, toFetch, toFetchAsync []int 139 140 // read from redis without sleep 141 rs, err := c.luaGetBatch(ctx, keys, owner) 142 if err != nil { 143 return nil, err 144 } 145 for i, v := range rs { 146 r := v.([]interface{}) 147 148 if r[0] == nil { 149 if r[1] == locked { 150 toFetch = append(toFetch, i) 151 } else { 152 toGet = append(toGet, i) 153 } 154 continue 155 } 156 157 if r[1] == locked { 158 toFetchAsync = append(toFetchAsync, i) 159 // fallthrough with old data 160 } // else new data 161 162 result[i] = r[0].(string) 163 } 164 165 if len(toFetchAsync) > 0 { 166 go func(idxs []int) { 167 debugf("batch weak: async fetch keys=%+v", keys) 168 _, _ = c.fetchBatch(ctx, keys, idxs, expire, owner, fn) 169 }(toFetchAsync) 170 toFetchAsync = toFetchAsync[:0] // reset toFetch 171 } 172 173 if len(toFetch) > 0 { 174 // batch fetch 175 fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn) 176 if err != nil { 177 return nil, err 178 } 179 for _, k := range toFetch { 180 result[k] = fetched[k] 181 } 182 toFetch = toFetch[:0] // reset toFetch 183 } 184 185 if len(toGet) > 0 { 186 // read from redis and sleep to wait 187 var wg sync.WaitGroup 188 189 var ch = make(chan pair, len(toGet)) 190 for _, idx := range toGet { 191 wg.Add(1) 192 go func(i int) { 193 defer wg.Done() 194 r, err := c.luaGet(ctx, keys[i], owner) 195 for err == nil && r[0] == nil && r[1].(string) != locked { 196 debugf("batch weak: empty result for %s locked by other, so sleep %s", keys[i], c.Options.LockSleep.String()) 197 time.Sleep(c.Options.LockSleep) 198 r, err = c.luaGet(ctx, keys[i], owner) 199 } 200 if err != nil { 201 ch <- pair{idx: i, data: "", err: err} 202 return 203 } 204 if r[1] != locked { // normal value 205 ch <- pair{idx: i, data: r[0].(string), err: nil} 206 return 207 } 208 if r[0] == nil { 209 ch <- pair{idx: i, data: "", err: errNeedFetch} 210 return 211 } 212 ch <- pair{idx: i, data: "", err: errNeedAsyncFetch} 213 }(idx) 214 } 215 wg.Wait() 216 close(ch) 217 218 for p := range ch { 219 if p.err != nil { 220 switch p.err { 221 case errNeedFetch: 222 toFetch = append(toFetch, p.idx) 223 continue 224 case errNeedAsyncFetch: 225 toFetchAsync = append(toFetchAsync, p.idx) 226 continue 227 default: 228 } 229 return nil, p.err 230 } 231 result[p.idx] = p.data 232 } 233 } 234 235 if len(toFetchAsync) > 0 { 236 go func(idxs []int) { 237 debugf("batch weak: async 2 fetch keys=%+v", keys) 238 _, _ = c.fetchBatch(ctx, keys, idxs, expire, owner, fn) 239 }(toFetchAsync) 240 } 241 242 if len(toFetch) > 0 { 243 // batch fetch 244 fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn) 245 if err != nil { 246 return nil, err 247 } 248 for _, k := range toFetch { 249 result[k] = fetched[k] 250 } 251 } 252 253 return result, nil 254 } 255 256 func (c *Client) strongFetchBatch(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) { 257 debugf("batch: strongFetch keys=%+v", keys) 258 var result = make(map[int]string) 259 owner := shortuuid.New() 260 var toGet, toFetch []int 261 262 // read from redis without sleep 263 rs, err := c.luaGetBatch(ctx, keys, owner) 264 if err != nil { 265 return nil, err 266 } 267 for i, v := range rs { 268 r := v.([]interface{}) 269 if r[1] == nil { // normal value 270 result[i] = r[0].(string) 271 continue 272 } 273 274 if r[1] != locked { // locked by other 275 debugf("batch: locked by other, continue idx=%d", i) 276 toGet = append(toGet, i) 277 continue 278 } 279 280 // locked for fetch 281 toFetch = append(toFetch, i) 282 } 283 284 if len(toFetch) > 0 { 285 // batch fetch 286 fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn) 287 if err != nil { 288 return nil, err 289 } 290 for _, k := range toFetch { 291 result[k] = fetched[k] 292 } 293 toFetch = toFetch[:0] // reset toFetch 294 } 295 296 if len(toGet) > 0 { 297 // read from redis and sleep to wait 298 var wg sync.WaitGroup 299 var ch = make(chan pair, len(toGet)) 300 for _, idx := range toGet { 301 wg.Add(1) 302 go func(i int) { 303 defer wg.Done() 304 r, err := c.luaGet(ctx, keys[i], owner) 305 for err == nil && r[1] != nil && r[1] != locked { // locked by other 306 debugf("batch: locked by other, so sleep %s", c.Options.LockSleep) 307 time.Sleep(c.Options.LockSleep) 308 r, err = c.luaGet(ctx, keys[i], owner) 309 } 310 if err != nil { 311 ch <- pair{idx: i, data: "", err: err} 312 return 313 } 314 if r[1] != locked { // normal value 315 ch <- pair{idx: i, data: r[0].(string), err: nil} 316 return 317 } 318 // locked for update 319 ch <- pair{idx: i, data: "", err: errNeedFetch} 320 }(idx) 321 } 322 wg.Wait() 323 close(ch) 324 for p := range ch { 325 if p.err != nil { 326 if p.err == errNeedFetch { 327 toFetch = append(toFetch, p.idx) 328 continue 329 } 330 return nil, p.err 331 } 332 result[p.idx] = p.data 333 } 334 } 335 336 if len(toFetch) > 0 { 337 // batch fetch 338 fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn) 339 if err != nil { 340 return nil, err 341 } 342 for _, k := range toFetch { 343 result[k] = fetched[k] 344 } 345 } 346 347 return result, nil 348 } 349 350 // FetchBatch returns a map with values indexed by index of keys list. 351 // 1. the first parameter is the keys list of the data 352 // 2. the second parameter is the data expiration time 353 // 3. the third parameter is the batch data fetch function which is called when the cache does not exist 354 // the parameter of the batch data fetch function is the index list of those keys 355 // missing in cache, which can be used to form a batch query for missing data. 356 // the return value of the batch data fetch function is a map, with key of the 357 // index and value of the corresponding data in form of string 358 func (c *Client) FetchBatch(keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) { 359 return c.FetchBatch2(c.Options.Context, keys, expire, fn) 360 } 361 362 // FetchBatch2 is same with FetchBatch, except that a user defined context.Context can be provided. 363 func (c *Client) FetchBatch2(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) { 364 if c.Options.DisableCacheRead { 365 return fn(c.keysIdx(keys)) 366 } else if c.Options.StrongConsistency { 367 return c.strongFetchBatch(ctx, keys, expire, fn) 368 } 369 return c.weakFetchBatch(ctx, keys, expire, fn) 370 } 371 372 // TagAsDeletedBatch a key list, the keys in list will expire after delay time. 373 func (c *Client) TagAsDeletedBatch(keys []string) error { 374 return c.TagAsDeletedBatch2(c.Options.Context, keys) 375 } 376 377 // TagAsDeletedBatch2 a key list, the keys in list will expire after delay time. 378 func (c *Client) TagAsDeletedBatch2(ctx context.Context, keys []string) error { 379 if c.Options.DisableCacheDelete { 380 return nil 381 } 382 debugf("batch deleting: keys=%v", keys) 383 luaFn := func(con redisConn) error { 384 _, err := callLua(ctx, con, ` -- luaDeleteBatch 385 for i, key in ipairs(KEYS) do 386 redis.call('HSET', key, 'lockUntil', 0) 387 redis.call('HDEL', key, 'lockOwner') 388 redis.call('EXPIRE', key, ARGV[1]) 389 end 390 `, keys, []interface{}{int64(c.Options.Delay / time.Second)}) 391 return err 392 } 393 if c.Options.WaitReplicas > 0 { 394 err := luaFn(c.rdb) 395 cmd := redis.NewCmd(ctx, "WAIT", c.Options.WaitReplicas, c.Options.WaitReplicasTimeout) 396 if err == nil { 397 err = c.rdb.Process(ctx, cmd) 398 } 399 var replicas int 400 if err == nil { 401 replicas, err = cmd.Int() 402 } 403 if err == nil && replicas < c.Options.WaitReplicas { 404 err = fmt.Errorf("wait replicas %d failed. result replicas: %d", c.Options.WaitReplicas, replicas) 405 } 406 return err 407 } 408 return luaFn(c.rdb) 409 }