github.com/wfusion/gofusion@v1.1.14/lock/mysql.go (about)

     1  package lock
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  
    12  	"github.com/wfusion/gofusion/common/utils"
    13  	"github.com/wfusion/gofusion/config"
    14  	"github.com/wfusion/gofusion/db"
    15  	"github.com/wfusion/gofusion/routine"
    16  )
    17  
    18  const (
    19  	mysqlLockSQL   = "SELECT GET_LOCK(?, ?)"
    20  	mysqlUnlockSQL = "DO RELEASE_LOCK(?)"
    21  )
    22  
    23  type mysqlLocker struct {
    24  	ctx     context.Context
    25  	dbName  string
    26  	appName string
    27  
    28  	locker     sync.RWMutex
    29  	lockTimers map[string]struct{}
    30  }
    31  
    32  func newMysqlLocker(ctx context.Context, appName, dbName string) Lockable {
    33  	return &mysqlLocker{ctx: ctx, appName: appName, dbName: dbName, lockTimers: map[string]struct{}{}}
    34  }
    35  
    36  func (m *mysqlLocker) Lock(ctx context.Context, key string, opts ...utils.OptionExtender) (err error) {
    37  	opt := utils.ApplyOptions[lockOption](opts...)
    38  	expired := tolerance
    39  	if opt.expired > 0 {
    40  		expired = opt.expired
    41  	}
    42  	lockKey := m.formatLockKey(key)
    43  	if len(lockKey) > 64 {
    44  		return errors.Errorf("key %s length is too long, max key length is 64", lockKey)
    45  	}
    46  
    47  	m.locker.Lock()
    48  	defer m.locker.Unlock()
    49  	// disable reentrant
    50  	if _, ok := m.lockTimers[lockKey]; ok {
    51  		return ErrTimeout
    52  	}
    53  
    54  	ret := db.Use(ctx, m.dbName, db.AppName(m.appName)).Raw(mysqlLockSQL, lockKey, 0)
    55  	if err = ret.Error; err != nil {
    56  		return ret.Error
    57  	}
    58  	var result int64
    59  	if err = ret.Scan(&result).Error; err != nil {
    60  		return
    61  	}
    62  	if result != 1 {
    63  		return ErrTimeout
    64  	}
    65  
    66  	// expire loop
    67  	m.lockTimers[lockKey] = struct{}{}
    68  	timer := time.NewTimer(expired)
    69  	routine.Loop(
    70  		func(ctx context.Context, key string, timer *time.Timer) {
    71  			defer timer.Stop()
    72  
    73  			lockKey := m.formatLockKey(key)
    74  			if !m.isLocked(ctx, lockKey) {
    75  				return
    76  			}
    77  
    78  			for {
    79  				select {
    80  				case <-ctx.Done():
    81  					_ = m.Unlock(ctx, key) // context done
    82  					return
    83  				case <-m.ctx.Done():
    84  					_ = m.Unlock(ctx, key) // context done
    85  					return
    86  				case <-timer.C:
    87  					_ = m.Unlock(ctx, key) // timeout
    88  					return
    89  				default:
    90  					if !m.isLocked(ctx, lockKey) {
    91  						return
    92  					}
    93  					time.Sleep(200*time.Millisecond + time.Duration(rand.Int63())%(100*time.Millisecond))
    94  				}
    95  			}
    96  		}, routine.Args(ctx, key, timer), routine.AppName(m.appName))
    97  	return
    98  }
    99  
   100  func (m *mysqlLocker) Unlock(ctx context.Context, key string, _ ...utils.OptionExtender) (err error) {
   101  	lockKey := m.formatLockKey(key)
   102  	if err = db.Use(ctx, m.dbName, db.AppName(m.appName)).Raw(mysqlUnlockSQL, lockKey).Error; err != nil {
   103  		return
   104  	}
   105  	m.locker.Lock()
   106  	defer m.locker.Unlock()
   107  	delete(m.lockTimers, lockKey)
   108  
   109  	return
   110  }
   111  
   112  func (m *mysqlLocker) isLocked(ctx context.Context, lockKey string) (locked bool) {
   113  	m.locker.RLock()
   114  	defer m.locker.RUnlock()
   115  	_, locked = m.lockTimers[lockKey]
   116  	return
   117  }
   118  
   119  func (m *mysqlLocker) formatLockKey(key string) string {
   120  	return fmt.Sprintf("%s:%s", config.Use(m.appName).AppName(), key)
   121  }