github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/pkg/lock/simple_redis.go

package lock

import (
	"context"
	"errors"
	"math/rand"
	"strconv"
	"sync"
	"time"

	"github.com/cozy/cozy-stack/pkg/logger"
	"github.com/cozy/cozy-stack/pkg/prefixer"
	"github.com/cozy/cozy-stack/pkg/utils"
	"github.com/redis/go-redis/v9"
)

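// These scripts run atomically inside redis: they refresh or delete the
// key only when its current value matches the token held by this stack,
// so a lock taken by another instance is never extended or released here.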
const luaRefresh = `if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`
const luaRelease = `if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`

type subRedisInterface interface {
	SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd
	Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
}

const (
	basicLockNS = "locks:"

	// lockTokenSize is the size of the random token used to ensure a lock
	// is ours. If two stacks were to generate the same token, the lock
	// would break.
	lockTokenSize = 20

	// LockTimeout is the expiration of a redis lock. If an operation takes
	// longer than this, it should refresh the lock.
	LockTimeout = 20 * time.Second

	// WaitTimeout is the maximum time to wait before returning control to
	// the caller.
	WaitTimeout = 1 * time.Minute

	// WaitRetry is the time to wait between retries.
	WaitRetry = 100 * time.Millisecond
)

var (
	// ErrTooManyRetries is the error returned when, despite several tries,
	// we never managed to get the lock.
	ErrTooManyRetries = errors.New("abort after too many failures without getting the lock")
)

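// redisRng is not safe for concurrent use, so redislocksMu must be held
// while generating a random lock token from it.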
var redislocksMu sync.Mutex
var redisRng *rand.Rand
var redisLogger logger.Logger

type RedisLockGetter struct {
	client redis.UniversalClient
	locks  *sync.Map
}

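// NewRedisLockGetter returns a lock getter backed by the given redis
// client. It also seeds the package-level RNG used to generate lock
// tokens and initializes the "redis-lock" logger namespace.
//
// A minimal usage sketch (the redis address and the db prefixer are
// assumptions for the example):
//
//	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
//	getter := NewRedisLockGetter(client)
//	l := getter.ReadWrite(db, "my-resource") // db is any prefixer.Prefixer
//	if err := l.Lock(); err != nil {
//		return err
//	}
//	defer l.Unlock()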
func NewRedisLockGetter(client redis.UniversalClient) *RedisLockGetter {
	redisRng = rand.New(rand.NewSource(time.Now().UnixNano()))
	redisLogger = logger.WithNamespace("redis-lock")

	return &RedisLockGetter{
		client: client,
		locks:  new(sync.Map),
	}
}

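// ReadWrite returns the read/write lock for the given name, namespaced by
// the database prefix. The sync.Map guarantees that a single *redisLock is
// shared for a given key inside one stack process.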
func (r *RedisLockGetter) ReadWrite(db prefixer.Prefixer, name string) ErrorRWLocker {
	ns := db.DBPrefix() + "/" + name
	lock, _ := r.locks.LoadOrStore(ns, &redisLock{
		client:    r.client,
		ctx:       context.Background(),
		timeout:   LockTimeout,
		waitRetry: WaitRetry,
		key:       basicLockNS + ns,
	})

	return lock.(*redisLock)
}

// LongOperation returns a lock suitable for long operations. It will refresh
// the lock in redis to avoid its automatic expiration.
func (r *RedisLockGetter) LongOperation(db prefixer.Prefixer, name string) ErrorLocker {
	return &longOperation{
		lock:    r.ReadWrite(db, name).(*redisLock),
		timeout: LockTimeout,
	}
}

type redisLock struct {
	client    subRedisInterface
	ctx       context.Context
	mu        sync.Mutex
	timeout   time.Duration
	waitRetry time.Duration
	key       string
	token     string
	// readers is the number of readers when the lock is acquired for reading
	// or -1 when it is locked for writing. 0 means that the lock is free.
	readers int
}

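// Lock acquires the lock for writing. It retries every waitRetry until the
// lock is obtained or the deadline expires, in which case it returns
// ErrTooManyRetries.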
func (rl *redisLock) Lock() error {
	// Compute the deadline after which we stop retrying.
	stop := time.Now().Add(rl.timeout)

	redislocksMu.Lock()
	token := utils.RandomStringFast(redisRng, lockTokenSize)
	redislocksMu.Unlock()

	for {
		ok, err := rl.obtainsWriting(token)
		if err != nil || ok {
			return err
		}
		if time.Now().Add(rl.waitRetry).After(stop) {
			return ErrTooManyRetries
		}
		time.Sleep(rl.waitRetry)
	}
}

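// Extend resets the expiration of the lock in redis to LockTimeout. It is
// used to keep a lock alive during an operation longer than its TTL.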
func (rl *redisLock) Extend() {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	_, _ = rl.extends()
}

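// RLock acquires the lock for reading. Several readers from the same stack
// process can hold it at once; the retry loop and the deadline behave like
// in Lock.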
func (rl *redisLock) RLock() error {
	// Note that the current code does not try to allow two cozy-stacks to
	// share a lock for reading. If one cozy-stack has taken a lock for
	// reading, another cozy-stack will have to wait until the lock has been
	// released before it can take a lock for reading on the same name. It
	// may be improved, but I prefer to err on the safe side for now. And it
	// still allows two readers on the same cozy-stack.

	stop := time.Now().Add(rl.timeout)

	redislocksMu.Lock()
	token := utils.RandomStringFast(redisRng, lockTokenSize)
	redislocksMu.Unlock()

	for {
		ok, err := rl.extendsOrObtainsReading(token)
		if err != nil || ok {
			return err
		}
		if time.Now().Add(rl.waitRetry).After(stop) {
			return ErrTooManyRetries
		}
		time.Sleep(rl.waitRetry)
	}
}

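// obtainsWriting tries to take the lock for writing. It fails fast when the
// lock is already held by this process (readers != 0).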
func (rl *redisLock) obtainsWriting(token string) (bool, error) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	if rl.readers != 0 {
		return false, nil
	}
	return rl.obtains(true, token)
}

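// extendsOrObtainsReading tries to take the lock for reading. If this
// process already holds the key in redis, it just refreshes the TTL and
// adds a reader; otherwise it tries to set the key.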
func (rl *redisLock) extendsOrObtainsReading(token string) (bool, error) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	if rl.readers < 0 {
		return false, nil
	}
	ok, err := rl.extends()
	if ok {
		rl.readers++
		return true, nil
	}
	if err != nil {
		return false, err
	}
	return rl.obtains(false, token)
}

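// obtains tries to set the lock key in redis with SetNX. On success, it
// records the token and marks the lock as held for writing (-1) or adds a
// reader.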
func (rl *redisLock) obtains(writing bool, token string) (bool, error) {
	// Try to obtain a lock
	ok, err := rl.client.SetNX(rl.ctx, rl.key, token, rl.timeout).Result()
	if err != nil {
		return false, err // most probably a redis connectivity error
	}
	if !ok {
		return false, nil
	}

	rl.token = token
	if writing {
		rl.readers = -1
	} else {
		rl.readers++
	}
	return true, nil
}

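// extends refreshes the TTL of a lock we already hold, using the
// token-guarded Lua script so that a lock taken by another instance is
// never touched.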
func (rl *redisLock) extends() (bool, error) {
	if rl.token == "" {
		return false, nil
	}

	// We already have a lock, attempt to extend it.
	ttl := strconv.FormatInt(int64(LockTimeout/time.Millisecond), 10)
	ret, err := rl.client.Eval(rl.ctx, luaRefresh, []string{rl.key}, rl.token, ttl).Result()
	if err != nil {
		return false, err // most probably a redis connectivity error
	}
	return ret == int64(1), nil
}

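// Unlock releases a lock held for writing.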
func (rl *redisLock) Unlock() {
	rl.unlock(true)
}

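// RUnlock releases a lock held for reading.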
func (rl *redisLock) RUnlock() {
	rl.unlock(false)
}

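// unlock checks that the lock is held in the expected mode, decrements the
// reader count while other readers remain, and otherwise releases the key
// in redis with the token-guarded script.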
func (rl *redisLock) unlock(writing bool) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if (writing && rl.readers > 0) || (!writing && rl.readers < 0) {
		redisLogger.Errorf("Invalid unlocking: %v %d (%s)", writing, rl.readers, rl.key)
		return
	}

	if !writing && rl.readers > 1 {
		rl.readers--
		return
	}

	_, err := rl.client.Eval(rl.ctx, luaRelease, []string{rl.key}, rl.token).Result()
	if err != nil {
		redisLogger.Warnf("Failed to unlock: %s (%s)", err.Error(), rl.key)
	}

	rl.readers = 0
	rl.token = ""
}