codeberg.org/gruf/go-mutexes@v1.5.0/map.go (about)

     1  package mutexes
     2  
     3  import (
     4  	"sync"
     5  	"sync/atomic"
     6  	"unsafe"
     7  
     8  	"codeberg.org/gruf/go-mempool"
     9  	"github.com/dolthub/swiss"
    10  )
    11  
const (
	// possible lock types.
	lockTypeRead  = uint8(1) << 0 // shared (reader) lock
	lockTypeWrite = uint8(1) << 1 // exclusive (writer) lock
)
    17  
// MutexMap is a structure that allows read / write locking
// per key, performing as you'd expect a map[string]*RWMutex
// to perform, without you needing to worry about deadlocks
// between competing read / write locks and the map's own mutex.
// It uses memory pooling for the internal "mutex" (ish) types
// and performs self-eviction of keys.
//
// Under the hood this is achieved using a single mutex for the
// map, state tracking for individual keys, and some sync.Cond{}
// like structures for sleeping / awaking awaiting goroutines.
type MutexMap struct {
	mapmu  sync.Mutex                   // guards mumap and all contained rwmutex state
	mumap  *swiss.Map[string, *rwmutex] // per-key lock state; keys self-evict on full unlock
	mupool mempool.UnsafePool           // recycles rwmutex allocations between keys
}
    33  
    34  // checkInit ensures MutexMap is initialized (UNSAFE).
    35  func (mm *MutexMap) checkInit() {
    36  	if mm.mumap == nil {
    37  		mm.mumap = swiss.NewMap[string, *rwmutex](0)
    38  		mm.mupool.DirtyFactor = 256
    39  	}
    40  }
    41  
// Lock acquires a write lock on key in map, returning unlock function.
// The returned function must be called exactly once to release the lock.
func (mm *MutexMap) Lock(key string) func() {
	return mm.lock(key, lockTypeWrite)
}
    46  
// RLock acquires a read lock on key in map, returning runlock function.
// The returned function must be called exactly once to release the lock.
func (mm *MutexMap) RLock(key string) func() {
	return mm.lock(key, lockTypeRead)
}
    51  
// lock acquires a lock of type lt on key, blocking (via the
// rwmutex's notify list) until any incompatible holders release.
// It returns the function that undoes this one acquisition.
func (mm *MutexMap) lock(key string, lt uint8) func() {
	// Perform first map lock
	// and check initialization
	// OUTSIDE the main loop.
	mm.mapmu.Lock()
	mm.checkInit()

	for {
		// Check map for mutex.
		mu, _ := mm.mumap.Get(key)

		if mu == nil {
			// Allocate mutex.
			mu = mm.acquire()
			mm.mumap.Put(key, mu)
		}

		if !mu.Lock(lt) {
			// Incompatible lock currently held.
			// Wait on mutex unlock, after
			// immediately relocking map mu.
			mu.WaitRelock(&mm.mapmu)
			// Key state may have changed while we
			// slept, so loop back and re-fetch from
			// the map rather than reusing 'mu'.
			continue
		}

		// Done with map.
		mm.mapmu.Unlock()

		// Return mutex unlock function. Captures both key
		// and mu so unlock() can self-evict when last out.
		return func() { mm.unlock(key, mu) }
	}
}
    83  
    84  func (mm *MutexMap) unlock(key string, mu *rwmutex) {
    85  	// Get map lock.
    86  	mm.mapmu.Lock()
    87  
    88  	// Unlock mutex.
    89  	if !mu.Unlock() {
    90  
    91  		// Fast path. Mutex still
    92  		// used so no map change.
    93  		mm.mapmu.Unlock()
    94  		return
    95  	}
    96  
    97  	// Mutex fully unlocked
    98  	// with zero waiters. Self
    99  	// evict and release it.
   100  	mm.mumap.Delete(key)
   101  	mm.release(mu)
   102  
   103  	// Maximum load factor before
   104  	// 'swiss' allocates new hmap:
   105  	// maxLoad = 7 / 8
   106  	//
   107  	// So we apply the inverse/2, once
   108  	// $maxLoad/2 % of hmap is empty we
   109  	// compact the map to drop buckets.
   110  	len := mm.mumap.Count()
   111  	cap := mm.mumap.Capacity()
   112  	if cap-len > (cap*7)/(8*2) {
   113  
   114  		// Create a new map only as big as required.
   115  		mumap := swiss.NewMap[string, *rwmutex](uint32(len))
   116  		mm.mumap.Iter(func(k string, v *rwmutex) (stop bool) {
   117  			mumap.Put(k, v)
   118  			return false
   119  		})
   120  
   121  		// Set new map.
   122  		mm.mumap = mumap
   123  	}
   124  
   125  	// Done with map.
   126  	mm.mapmu.Unlock()
   127  }
   128  
   129  // acquire will acquire mutex from memory pool, or alloc new.
   130  func (mm *MutexMap) acquire() *rwmutex {
   131  	if ptr := mm.mupool.Get(); ptr != nil {
   132  		return (*rwmutex)(ptr)
   133  	}
   134  	return new(rwmutex)
   135  }
   136  
   137  // release will release given mutex to memory pool.
   138  func (mm *MutexMap) release(mu *rwmutex) {
   139  	ptr := unsafe.Pointer(mu)
   140  	mm.mupool.Put(ptr)
   141  }
   142  
// rwmutex represents a RW mutex when used correctly within
// a MutexMap. It should ONLY be accessed when protected by
// the outer map lock, except for the 'notifyList' which is
// a runtime internal structure borrowed from the sync.Cond{}.
//
// this functions very similarly to a sync.Cond{}, but with
// lock state tracking, and returning on 'Broadcast()' whether
// any goroutines were actually awoken. it also has a less
// confusing API than sync.Cond{} with the outer locking
// mechanism we use, otherwise all Cond{}.L would reference
// the same outer map mutex.
type rwmutex struct {
	n notifyList // 'trigger' mechanism
	l int32      // no. locks currently held
	t uint8      // lock type (lockTypeRead / lockTypeWrite, 0 = unlocked)
}
   159  
   160  // Lock will lock the mutex for given lock type, in the
   161  // sense that it will update the internal state tracker
   162  // accordingly. Return value is true on successful lock.
   163  func (mu *rwmutex) Lock(lt uint8) bool {
   164  	switch mu.t {
   165  	case lockTypeRead:
   166  		// already read locked,
   167  		// only permit more reads.
   168  		if lt != lockTypeRead {
   169  			return false
   170  		}
   171  
   172  	case lockTypeWrite:
   173  		// already write locked,
   174  		// no other locks allowed.
   175  		return false
   176  
   177  	default:
   178  		// Fully unlocked,
   179  		// set incoming type.
   180  		mu.t = lt
   181  	}
   182  
   183  	// Update
   184  	// count.
   185  	mu.l++
   186  
   187  	return true
   188  }
   189  
// Unlock will unlock the mutex, in the sense that it
// will update the internal state tracker accordingly.
// On totally unlocked state, it will awaken all
// sleeping goroutines waiting on this mutex. Returns
// true only when fully unlocked with zero waiters,
// i.e. when the caller may safely evict this mutex.
func (mu *rwmutex) Unlock() bool {
	// Decrement, then switch on resulting state.
	switch mu.l--; {
	case mu.l > 0 && mu.t == lockTypeWrite:
		panic("BUG: multiple writer locks")
	case mu.l < 0:
		panic("BUG: negative lock count")

	case mu.l == 0:
		// Fully unlocked.
		mu.t = 0

		// Awake all blocked goroutines and check
		// for change in the last notified ticket.
		before := atomic.LoadUint32(&mu.n.notify)
		runtime_notifyListNotifyAll(&mu.n)
		after := atomic.LoadUint32(&mu.n.notify)

		// If ticket changed, this indicates
		// AT LEAST one goroutine was awoken.
		//
		// (before != after) => (waiters > 0)
		// (before == after) => (waiters = 0)
		return (before == after)

	default:
		// i.e. mutex still
		// locked by others.
		return false
	}
}
   224  
// WaitRelock expects a mutex to be passed in, already in the
// locked state. It incr the notifyList waiter count before
// unlocking the outer mutex and blocking on notifyList wait.
// On awake it will decr wait count and relock outer mutex.
//
// NOTE(review): the add -> unlock -> wait ordering mirrors
// sync.Cond.Wait and must not be reordered: registering on
// the notify list before releasing 'outer' ensures a
// concurrent Unlock() cannot broadcast between our release
// and our wait (a lost wakeup).
func (mu *rwmutex) WaitRelock(outer *sync.Mutex) {

	// add ourselves to list while still
	// under protection of outer map lock.
	t := runtime_notifyListAdd(&mu.n)

	// Finished with
	// outer map lock.
	outer.Unlock()

	// Block until awoken by another
	// goroutine within mu.Unlock().
	runtime_notifyListWait(&mu.n, t)

	// Relock!
	outer.Lock()
}