github.com/dtm-labs/rockscache@v0.1.1/batch.go (about)

     1  package rockscache
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"runtime/debug"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/lithammer/shortuuid"
    13  	"github.com/redis/go-redis/v9"
    14  )
    15  
// Sentinel errors used by the per-key polling goroutines in this file to tell
// the batch fetch loops what follow-up a key needs. They travel inside a pair
// over a channel and are compared by identity.
var (
	// errNeedFetch: this caller took the lock and must load the value before
	// it can return (synchronous fetch).
	errNeedFetch      = errors.New("need fetch")
	// errNeedAsyncFetch: this caller took the lock while a value was still
	// cached, so the refresh is queued for a background fetch.
	errNeedAsyncFetch = errors.New("need async fetch")
)
    20  
    21  func (c *Client) luaGetBatch(ctx context.Context, keys []string, owner string) ([]interface{}, error) {
    22  	res, err := callLua(ctx, c.rdb, `-- luaGetBatch
    23      local rets = {}
    24      for i, key in ipairs(KEYS)
    25      do
    26          local v = redis.call('HGET', key, 'value')
    27          local lu = redis.call('HGET', key, 'lockUntil')
    28          if lu ~= false and tonumber(lu) < tonumber(ARGV[1]) or lu == false and v == false then
    29              redis.call('HSET', key, 'lockUntil', ARGV[2])
    30              redis.call('HSET', key, 'lockOwner', ARGV[3])
    31              table.insert(rets, { v, 'LOCKED' })
    32          else
    33              table.insert(rets, {v, lu})
    34          end
    35      end
    36      return rets
    37  	`, keys, []interface{}{now(), now() + int64(c.Options.LockExpire/time.Second), owner})
    38  	debugf("luaGetBatch return: %v, %v", res, err)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	return res.([]interface{}), nil
    43  }
    44  
    45  func (c *Client) luaSetBatch(ctx context.Context, keys []string, values []string, expires []int, owner string) error {
    46  	var vals = make([]interface{}, 0, 2+len(values))
    47  	vals = append(vals, owner)
    48  	for _, v := range values {
    49  		vals = append(vals, v)
    50  	}
    51  	for _, ex := range expires {
    52  		vals = append(vals, ex)
    53  	}
    54  	_, err := callLua(ctx, c.rdb, `-- luaSetBatch
    55      local n = #KEYS
    56      for i, key in ipairs(KEYS)
    57      do
    58          local o = redis.call('HGET', key, 'lockOwner')
    59          if o ~= ARGV[1] then
    60                  return
    61          end
    62          redis.call('HSET', key, 'value', ARGV[i+1])
    63          redis.call('HDEL', key, 'lockUntil')
    64          redis.call('HDEL', key, 'lockOwner')
    65          redis.call('EXPIRE', key, ARGV[i+1+n])
    66      end
    67  	`, keys, vals)
    68  	return err
    69  }
    70  
    71  func (c *Client) fetchBatch(ctx context.Context, keys []string, idxs []int, expire time.Duration, owner string, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) {
    72  	defer func() {
    73  		if r := recover(); r != nil {
    74  			debug.PrintStack()
    75  		}
    76  	}()
    77  	data, err := fn(idxs)
    78  	if err != nil {
    79  		for _, idx := range idxs {
    80  			_ = c.UnlockForUpdate(ctx, keys[idx], owner)
    81  		}
    82  		return nil, err
    83  	}
    84  
    85  	if data == nil {
    86  		// incase data is nil
    87  		data = make(map[int]string)
    88  	}
    89  
    90  	var batchKeys []string
    91  	var batchValues []string
    92  	var batchExpires []int
    93  
    94  	for _, idx := range idxs {
    95  		v := data[idx]
    96  		ex := expire - c.Options.Delay - time.Duration(rand.Float64()*c.Options.RandomExpireAdjustment*float64(expire))
    97  		if v == "" {
    98  			if c.Options.EmptyExpire == 0 { // if empty expire is 0, then delete the key
    99  				_ = c.rdb.Del(ctx, keys[idx]).Err()
   100  				if err != nil {
   101  					debugf("batch: del failed key=%s err:%s", keys[idx], err.Error())
   102  				}
   103  				continue
   104  			}
   105  			ex = c.Options.EmptyExpire
   106  
   107  			data[idx] = v // incase idx not in data
   108  		}
   109  		batchKeys = append(batchKeys, keys[idx])
   110  		batchValues = append(batchValues, v)
   111  		batchExpires = append(batchExpires, int(ex/time.Second))
   112  	}
   113  
   114  	err = c.luaSetBatch(ctx, batchKeys, batchValues, batchExpires, owner)
   115  	if err != nil {
   116  		debugf("batch: luaSetBatch failed keys=%s err:%s", keys, err.Error())
   117  	}
   118  	return data, nil
   119  }
   120  
   121  func (c *Client) keysIdx(keys []string) (idxs []int) {
   122  	for i := range keys {
   123  		idxs = append(idxs, i)
   124  	}
   125  	return idxs
   126  }
   127  
// pair is the message a per-key polling goroutine sends back to the batch
// fetch loop once it has finished waiting on that key.
type pair struct {
	idx  int    // position of the key in the caller's keys slice
	data string // value read from the cache; left empty whenever err is set
	err  error  // nil for a plain value, errNeedFetch / errNeedAsyncFetch, or the read error
}
   133  
   134  func (c *Client) weakFetchBatch(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) {
   135  	debugf("batch: weakFetch keys=%+v", keys)
   136  	var result = make(map[int]string)
   137  	owner := shortuuid.New()
   138  	var toGet, toFetch, toFetchAsync []int
   139  
   140  	// read from redis without sleep
   141  	rs, err := c.luaGetBatch(ctx, keys, owner)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	for i, v := range rs {
   146  		r := v.([]interface{})
   147  
   148  		if r[0] == nil {
   149  			if r[1] == locked {
   150  				toFetch = append(toFetch, i)
   151  			} else {
   152  				toGet = append(toGet, i)
   153  			}
   154  			continue
   155  		}
   156  
   157  		if r[1] == locked {
   158  			toFetchAsync = append(toFetchAsync, i)
   159  			// fallthrough with old data
   160  		} // else new data
   161  
   162  		result[i] = r[0].(string)
   163  	}
   164  
   165  	if len(toFetchAsync) > 0 {
   166  		go func(idxs []int) {
   167  			debugf("batch weak: async fetch keys=%+v", keys)
   168  			_, _ = c.fetchBatch(ctx, keys, idxs, expire, owner, fn)
   169  		}(toFetchAsync)
   170  		toFetchAsync = toFetchAsync[:0] // reset toFetch
   171  	}
   172  
   173  	if len(toFetch) > 0 {
   174  		// batch fetch
   175  		fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  		for _, k := range toFetch {
   180  			result[k] = fetched[k]
   181  		}
   182  		toFetch = toFetch[:0] // reset toFetch
   183  	}
   184  
   185  	if len(toGet) > 0 {
   186  		// read from redis and sleep to wait
   187  		var wg sync.WaitGroup
   188  
   189  		var ch = make(chan pair, len(toGet))
   190  		for _, idx := range toGet {
   191  			wg.Add(1)
   192  			go func(i int) {
   193  				defer wg.Done()
   194  				r, err := c.luaGet(ctx, keys[i], owner)
   195  				for err == nil && r[0] == nil && r[1].(string) != locked {
   196  					debugf("batch weak: empty result for %s locked by other, so sleep %s", keys[i], c.Options.LockSleep.String())
   197  					time.Sleep(c.Options.LockSleep)
   198  					r, err = c.luaGet(ctx, keys[i], owner)
   199  				}
   200  				if err != nil {
   201  					ch <- pair{idx: i, data: "", err: err}
   202  					return
   203  				}
   204  				if r[1] != locked { // normal value
   205  					ch <- pair{idx: i, data: r[0].(string), err: nil}
   206  					return
   207  				}
   208  				if r[0] == nil {
   209  					ch <- pair{idx: i, data: "", err: errNeedFetch}
   210  					return
   211  				}
   212  				ch <- pair{idx: i, data: "", err: errNeedAsyncFetch}
   213  			}(idx)
   214  		}
   215  		wg.Wait()
   216  		close(ch)
   217  
   218  		for p := range ch {
   219  			if p.err != nil {
   220  				switch p.err {
   221  				case errNeedFetch:
   222  					toFetch = append(toFetch, p.idx)
   223  					continue
   224  				case errNeedAsyncFetch:
   225  					toFetchAsync = append(toFetchAsync, p.idx)
   226  					continue
   227  				default:
   228  				}
   229  				return nil, p.err
   230  			}
   231  			result[p.idx] = p.data
   232  		}
   233  	}
   234  
   235  	if len(toFetchAsync) > 0 {
   236  		go func(idxs []int) {
   237  			debugf("batch weak: async 2 fetch keys=%+v", keys)
   238  			_, _ = c.fetchBatch(ctx, keys, idxs, expire, owner, fn)
   239  		}(toFetchAsync)
   240  	}
   241  
   242  	if len(toFetch) > 0 {
   243  		// batch fetch
   244  		fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn)
   245  		if err != nil {
   246  			return nil, err
   247  		}
   248  		for _, k := range toFetch {
   249  			result[k] = fetched[k]
   250  		}
   251  	}
   252  
   253  	return result, nil
   254  }
   255  
   256  func (c *Client) strongFetchBatch(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) {
   257  	debugf("batch: strongFetch keys=%+v", keys)
   258  	var result = make(map[int]string)
   259  	owner := shortuuid.New()
   260  	var toGet, toFetch []int
   261  
   262  	// read from redis without sleep
   263  	rs, err := c.luaGetBatch(ctx, keys, owner)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	for i, v := range rs {
   268  		r := v.([]interface{})
   269  		if r[1] == nil { // normal value
   270  			result[i] = r[0].(string)
   271  			continue
   272  		}
   273  
   274  		if r[1] != locked { // locked by other
   275  			debugf("batch: locked by other, continue idx=%d", i)
   276  			toGet = append(toGet, i)
   277  			continue
   278  		}
   279  
   280  		// locked for fetch
   281  		toFetch = append(toFetch, i)
   282  	}
   283  
   284  	if len(toFetch) > 0 {
   285  		// batch fetch
   286  		fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn)
   287  		if err != nil {
   288  			return nil, err
   289  		}
   290  		for _, k := range toFetch {
   291  			result[k] = fetched[k]
   292  		}
   293  		toFetch = toFetch[:0] // reset toFetch
   294  	}
   295  
   296  	if len(toGet) > 0 {
   297  		// read from redis and sleep to wait
   298  		var wg sync.WaitGroup
   299  		var ch = make(chan pair, len(toGet))
   300  		for _, idx := range toGet {
   301  			wg.Add(1)
   302  			go func(i int) {
   303  				defer wg.Done()
   304  				r, err := c.luaGet(ctx, keys[i], owner)
   305  				for err == nil && r[1] != nil && r[1] != locked { // locked by other
   306  					debugf("batch: locked by other, so sleep %s", c.Options.LockSleep)
   307  					time.Sleep(c.Options.LockSleep)
   308  					r, err = c.luaGet(ctx, keys[i], owner)
   309  				}
   310  				if err != nil {
   311  					ch <- pair{idx: i, data: "", err: err}
   312  					return
   313  				}
   314  				if r[1] != locked { // normal value
   315  					ch <- pair{idx: i, data: r[0].(string), err: nil}
   316  					return
   317  				}
   318  				// locked for update
   319  				ch <- pair{idx: i, data: "", err: errNeedFetch}
   320  			}(idx)
   321  		}
   322  		wg.Wait()
   323  		close(ch)
   324  		for p := range ch {
   325  			if p.err != nil {
   326  				if p.err == errNeedFetch {
   327  					toFetch = append(toFetch, p.idx)
   328  					continue
   329  				}
   330  				return nil, p.err
   331  			}
   332  			result[p.idx] = p.data
   333  		}
   334  	}
   335  
   336  	if len(toFetch) > 0 {
   337  		// batch fetch
   338  		fetched, err := c.fetchBatch(ctx, keys, toFetch, expire, owner, fn)
   339  		if err != nil {
   340  			return nil, err
   341  		}
   342  		for _, k := range toFetch {
   343  			result[k] = fetched[k]
   344  		}
   345  	}
   346  
   347  	return result, nil
   348  }
   349  
// FetchBatch returns the values for keys as a map indexed by each key's
// position in the keys slice. It is FetchBatch2 called with the client's
// configured Options.Context.
//
// Parameters:
//   - keys: the cache keys to read.
//   - expire: the expiration time for the cached data.
//   - fn: the batch data fetch function, called for keys that are missing in
//     the cache. It receives the indexes (into keys) of the missing keys, which
//     can be used to build one batch query, and returns a map from index to the
//     corresponding data as a string.
func (c *Client) FetchBatch(keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) {
	return c.FetchBatch2(c.Options.Context, keys, expire, fn)
}
   361  
   362  // FetchBatch2 is same with FetchBatch, except that a user defined context.Context can be provided.
   363  func (c *Client) FetchBatch2(ctx context.Context, keys []string, expire time.Duration, fn func(idxs []int) (map[int]string, error)) (map[int]string, error) {
   364  	if c.Options.DisableCacheRead {
   365  		return fn(c.keysIdx(keys))
   366  	} else if c.Options.StrongConsistency {
   367  		return c.strongFetchBatch(ctx, keys, expire, fn)
   368  	}
   369  	return c.weakFetchBatch(ctx, keys, expire, fn)
   370  }
   371  
// TagAsDeletedBatch tags every key in keys as deleted; the keys will expire
// after the delay time. It is TagAsDeletedBatch2 called with the client's
// configured Options.Context.
func (c *Client) TagAsDeletedBatch(keys []string) error {
	return c.TagAsDeletedBatch2(c.Options.Context, keys)
}
   376  
   377  // TagAsDeletedBatch2 a key list, the keys in list will expire after delay time.
   378  func (c *Client) TagAsDeletedBatch2(ctx context.Context, keys []string) error {
   379  	if c.Options.DisableCacheDelete {
   380  		return nil
   381  	}
   382  	debugf("batch deleting: keys=%v", keys)
   383  	luaFn := func(con redisConn) error {
   384  		_, err := callLua(ctx, con, ` -- luaDeleteBatch
   385  		for i, key in ipairs(KEYS) do
   386  			redis.call('HSET', key, 'lockUntil', 0)
   387  			redis.call('HDEL', key, 'lockOwner')
   388  			redis.call('EXPIRE', key, ARGV[1])
   389  		end
   390  		`, keys, []interface{}{int64(c.Options.Delay / time.Second)})
   391  		return err
   392  	}
   393  	if c.Options.WaitReplicas > 0 {
   394  		err := luaFn(c.rdb)
   395  		cmd := redis.NewCmd(ctx, "WAIT", c.Options.WaitReplicas, c.Options.WaitReplicasTimeout)
   396  		if err == nil {
   397  			err = c.rdb.Process(ctx, cmd)
   398  		}
   399  		var replicas int
   400  		if err == nil {
   401  			replicas, err = cmd.Int()
   402  		}
   403  		if err == nil && replicas < c.Options.WaitReplicas {
   404  			err = fmt.Errorf("wait replicas %d failed. result replicas: %d", c.Options.WaitReplicas, replicas)
   405  		}
   406  		return err
   407  	}
   408  	return luaFn(c.rdb)
   409  }