github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfssync/leveled_mutex.go (about)

     1  // Copyright 2016 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package kbfssync
     6  
     7  import (
     8  	"fmt"
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
// The LeveledMutex, LeveledRWMutex, and LockState types enable a
// lock hierarchy to be checked. For a program (or subsystem), each
// (rw-)mutex must have a unique associated MutexLevel. Within a given
// execution flow, a (rw-)mutex must not be (r-)locked while another
// (rw-)mutex with a higher or equal MutexLevel is held; in other words,
// mutexes must be acquired in strictly increasing order of MutexLevel
// and released in the reverse order. This is achieved by creating a new
// LockState at the start of an execution flow and passing it to the
// (r-)lock/(r-)unlock methods of each (rw-)mutex.
//
// TODO: Once this becomes a bottleneck, add a +build production
// version that stubs everything out.
    23  
    24  // An exclusiveLock is a lock around something that is expected to be
    25  // accessed exclusively. It immediately panics upon any lock
    26  // contention.
    27  type exclusiveLock struct {
    28  	v *int32
    29  }
    30  
    31  func makeExclusiveLock() exclusiveLock {
    32  	return exclusiveLock{
    33  		v: new(int32),
    34  	}
    35  }
    36  
    37  func (l exclusiveLock) lock() {
    38  	if !atomic.CompareAndSwapInt32(l.v, 0, 1) {
    39  		panic("unexpected concurrent access")
    40  	}
    41  }
    42  
    43  func (l exclusiveLock) unlock() {
    44  	if !atomic.CompareAndSwapInt32(l.v, 1, 0) {
    45  		panic("unexpected concurrent access")
    46  	}
    47  }
    48  
    49  // MutexLevel is the level for a mutex, which must be unique to that
    50  // mutex.
    51  type MutexLevel int
    52  
    53  // exclusionType is the type of exclusion of a lock. A regular lock
    54  // always uses write exclusion, where only one thing at a time can
    55  // hold the lock, whereas a reader-writer lock can do either write
    56  // exclusion or read exclusion, where only one writer or any number of
    57  // readers can hold the lock.
    58  type exclusionType int
    59  
    60  const (
    61  	nonExclusion   exclusionType = 0
    62  	writeExclusion exclusionType = 1
    63  	readExclusion  exclusionType = 2
    64  )
    65  
    66  func (et exclusionType) prefix() string {
    67  	switch et {
    68  	case nonExclusion:
    69  		return "Un"
    70  	case writeExclusion:
    71  		return ""
    72  	case readExclusion:
    73  		return "R"
    74  	}
    75  	return fmt.Sprintf("exclusionType{%d}", et)
    76  }
    77  
    78  // exclusionState holds the state for a held mutex.
    79  type exclusionState struct {
    80  	// The level of the held mutex.
    81  	level MutexLevel
    82  	// The exclusion type of the held mutex.
    83  	exclusionType exclusionType
    84  }
    85  
    86  // LockState holds the info regarding which level mutexes are held or
    87  // not for a particular execution flow.
    88  type LockState struct {
    89  	levelToString func(MutexLevel) string
    90  
    91  	// Protects exclusionStates.
    92  	exclusionStatesLock exclusiveLock
    93  	// The stack of held mutexes, ordered by increasing level.
    94  	exclusionStates []exclusionState
    95  }
    96  
    97  // MakeLevelState returns a new LockState. This must be called at the
    98  // start of a new execution flow and passed to any LeveledMutex or
    99  // LeveledRWMutex operation during that execution flow.
   100  //
   101  // TODO: Consider adding a parameter to set the capacity of
   102  // exclusionStates.
   103  func MakeLevelState(levelToString func(MutexLevel) string) *LockState {
   104  	return &LockState{
   105  		levelToString:       levelToString,
   106  		exclusionStatesLock: makeExclusiveLock(),
   107  	}
   108  }
   109  
   110  // currLocked returns the current exclusion state, or nil if there is
   111  // none.
   112  func (state *LockState) currLocked() *exclusionState {
   113  	stateCount := len(state.exclusionStates)
   114  	if stateCount == 0 {
   115  		return nil
   116  	}
   117  	return &state.exclusionStates[stateCount-1]
   118  }
   119  
   120  type levelViolationError struct {
   121  	levelToString func(MutexLevel) string
   122  	level         MutexLevel
   123  	exclusionType exclusionType
   124  	curr          exclusionState
   125  }
   126  
   127  func (e levelViolationError) Error() string {
   128  	return fmt.Sprintf("level violation: %s %sLocked after %s %sLocked",
   129  		e.levelToString(e.level), e.exclusionType.prefix(),
   130  		e.levelToString(e.curr.level), e.curr.exclusionType.prefix())
   131  }
   132  
   133  func (state *LockState) doLock(
   134  	level MutexLevel, exclusionType exclusionType, lock sync.Locker) error {
   135  	state.exclusionStatesLock.lock()
   136  	defer state.exclusionStatesLock.unlock()
   137  
   138  	curr := state.currLocked()
   139  
   140  	if curr != nil && level <= curr.level {
   141  		return levelViolationError{
   142  			levelToString: state.levelToString,
   143  			level:         level,
   144  			exclusionType: exclusionType,
   145  			curr:          *curr,
   146  		}
   147  	}
   148  
   149  	lock.Lock()
   150  
   151  	state.exclusionStates = append(state.exclusionStates, exclusionState{
   152  		level:         level,
   153  		exclusionType: exclusionType,
   154  	})
   155  	return nil
   156  }
   157  
   158  type danglingUnlockError struct {
   159  	levelToString func(MutexLevel) string
   160  	level         MutexLevel
   161  	exclusionType exclusionType
   162  }
   163  
   164  func (e danglingUnlockError) Error() string {
   165  	return fmt.Sprintf("%s %sUnlocked while already unlocked",
   166  		e.levelToString(e.level), e.exclusionType.prefix())
   167  }
   168  
   169  type mismatchedUnlockError struct {
   170  	levelToString func(MutexLevel) string
   171  	level         MutexLevel
   172  	exclusionType exclusionType
   173  	curr          exclusionState
   174  }
   175  
   176  func (e mismatchedUnlockError) Error() string {
   177  	return fmt.Sprintf(
   178  		"%sUnlock call for %s doesn't match %sLock call for %s",
   179  		e.exclusionType.prefix(), e.levelToString(e.level),
   180  		e.curr.exclusionType.prefix(), e.levelToString(e.curr.level))
   181  }
   182  
   183  func (state *LockState) doUnlock(
   184  	level MutexLevel, exclusionType exclusionType, lock sync.Locker) error {
   185  	state.exclusionStatesLock.lock()
   186  	defer state.exclusionStatesLock.unlock()
   187  
   188  	curr := state.currLocked()
   189  
   190  	if curr == nil {
   191  		return danglingUnlockError{
   192  			levelToString: state.levelToString,
   193  			level:         level,
   194  			exclusionType: exclusionType,
   195  		}
   196  	}
   197  
   198  	if level != curr.level || curr.exclusionType != exclusionType {
   199  		return mismatchedUnlockError{
   200  			levelToString: state.levelToString,
   201  			level:         level,
   202  			exclusionType: exclusionType,
   203  			curr:          *curr,
   204  		}
   205  	}
   206  
   207  	lock.Unlock()
   208  
   209  	state.exclusionStates = state.exclusionStates[:len(state.exclusionStates)-1]
   210  	return nil
   211  }
   212  
   213  // getExclusionType returns returns the exclusionType for the given
   214  // MutexLevel, or nonExclusion if there is none.
   215  func (state *LockState) getExclusionType(level MutexLevel) exclusionType {
   216  	state.exclusionStatesLock.lock()
   217  	defer state.exclusionStatesLock.unlock()
   218  
   219  	// Not worth it to do anything more complicated than a
   220  	// brute-force search.
   221  	for _, state := range state.exclusionStates {
   222  		if state.level > level {
   223  			break
   224  		}
   225  		if state.level == level {
   226  			return state.exclusionType
   227  		}
   228  	}
   229  
   230  	return nonExclusion
   231  }
   232  
   233  // LeveledMutex is a mutex with an associated level, which must be
   234  // unique. Note that unlike sync.Mutex, LeveledMutex is a reference
   235  // type and not a value type.
   236  type LeveledMutex struct {
   237  	level  MutexLevel
   238  	locker sync.Locker
   239  }
   240  
   241  // MakeLeveledMutex makes a mutex with the given level, backed by the
   242  // given locker.
   243  func MakeLeveledMutex(level MutexLevel, locker sync.Locker) LeveledMutex {
   244  	return LeveledMutex{
   245  		level:  level,
   246  		locker: locker,
   247  	}
   248  }
   249  
   250  // Lock locks the associated locker.
   251  func (m LeveledMutex) Lock(lockState *LockState) {
   252  	err := lockState.doLock(m.level, writeExclusion, m.locker)
   253  	if err != nil {
   254  		panic(err)
   255  	}
   256  }
   257  
   258  // Unlock locks the associated locker.
   259  func (m LeveledMutex) Unlock(lockState *LockState) {
   260  	err := lockState.doUnlock(m.level, writeExclusion, m.locker)
   261  	if err != nil {
   262  		panic(err)
   263  	}
   264  }
   265  
   266  type unexpectedExclusionError struct {
   267  	levelToString func(MutexLevel) string
   268  	level         MutexLevel
   269  	exclusionType exclusionType
   270  }
   271  
   272  func (e unexpectedExclusionError) Error() string {
   273  	return fmt.Sprintf("%s unexpectedly %sLocked",
   274  		e.levelToString(e.level), e.exclusionType.prefix())
   275  }
   276  
   277  // AssertUnlocked does nothing if m is unlocked with respect to the
   278  // given LockState. Otherwise, it panics.
   279  func (m LeveledMutex) AssertUnlocked(lockState *LockState) {
   280  	et := lockState.getExclusionType(m.level)
   281  	if et != nonExclusion {
   282  		panic(unexpectedExclusionError{
   283  			levelToString: lockState.levelToString,
   284  			level:         m.level,
   285  			exclusionType: et,
   286  		})
   287  	}
   288  }
   289  
   290  type unexpectedExclusionTypeError struct {
   291  	levelToString         func(MutexLevel) string
   292  	level                 MutexLevel
   293  	expectedExclusionType exclusionType
   294  	exclusionType         exclusionType
   295  }
   296  
   297  func (e unexpectedExclusionTypeError) Error() string {
   298  	return fmt.Sprintf(
   299  		"%s unexpectedly not %sLocked; instead it is %sLocked",
   300  		e.levelToString(e.level),
   301  		e.expectedExclusionType.prefix(),
   302  		e.exclusionType.prefix())
   303  }
   304  
   305  // AssertLocked does nothing if m is locked with respect to the given
   306  // LockState. Otherwise, it panics.
   307  func (m LeveledMutex) AssertLocked(lockState *LockState) {
   308  	et := lockState.getExclusionType(m.level)
   309  	if et != writeExclusion {
   310  		panic(unexpectedExclusionTypeError{
   311  			levelToString:         lockState.levelToString,
   312  			level:                 m.level,
   313  			expectedExclusionType: writeExclusion,
   314  			exclusionType:         et,
   315  		})
   316  	}
   317  }
   318  
   319  // LeveledLocker represents an object that can be locked and unlocked
   320  // with a LockState.
   321  type LeveledLocker interface {
   322  	Lock(*LockState)
   323  	Unlock(*LockState)
   324  }
   325  
   326  // LeveledRWMutex is a reader-writer mutex with an associated level,
   327  // which must be unique. Note that unlike sync.RWMutex, LeveledRWMutex
   328  // is a reference type and not a value type.
   329  type LeveledRWMutex struct {
   330  	level    MutexLevel
   331  	rwLocker rwLocker
   332  }
   333  
   334  // MakeLeveledRWMutex makes a reader-writer mutex with the given
   335  // level, backed by the given rwLocker.
   336  func MakeLeveledRWMutex(level MutexLevel, rwLocker rwLocker) LeveledRWMutex {
   337  	return LeveledRWMutex{
   338  		level:    level,
   339  		rwLocker: rwLocker,
   340  	}
   341  }
   342  
   343  // Lock locks the associated locker.
   344  func (rw LeveledRWMutex) Lock(lockState *LockState) {
   345  	err := lockState.doLock(rw.level, writeExclusion, rw.rwLocker)
   346  	if err != nil {
   347  		panic(err)
   348  	}
   349  }
   350  
   351  // Unlock unlocks the associated locker.
   352  func (rw LeveledRWMutex) Unlock(lockState *LockState) {
   353  	err := lockState.doUnlock(rw.level, writeExclusion, rw.rwLocker)
   354  	if err != nil {
   355  		panic(err)
   356  	}
   357  }
   358  
   359  // RLock locks the associated locker for reading.
   360  func (rw LeveledRWMutex) RLock(lockState *LockState) {
   361  	err := lockState.doLock(rw.level, readExclusion, rw.rwLocker.RLocker())
   362  	if err != nil {
   363  		panic(err)
   364  	}
   365  }
   366  
   367  // RUnlock unlocks the associated locker for reading.
   368  func (rw LeveledRWMutex) RUnlock(lockState *LockState) {
   369  	err := lockState.doUnlock(rw.level, readExclusion, rw.rwLocker.RLocker())
   370  	if err != nil {
   371  		panic(err)
   372  	}
   373  }
   374  
   375  // AssertUnlocked does nothing if m is unlocked with respect to the
   376  // given LockState. Otherwise, it panics.
   377  func (rw LeveledRWMutex) AssertUnlocked(lockState *LockState) {
   378  	et := lockState.getExclusionType(rw.level)
   379  	if et != nonExclusion {
   380  		panic(unexpectedExclusionError{
   381  			levelToString: lockState.levelToString,
   382  			level:         rw.level,
   383  			exclusionType: et,
   384  		})
   385  	}
   386  }
   387  
   388  // AssertLocked does nothing if m is locked with respect to the given
   389  // LockState. Otherwise, it panics.
   390  func (rw LeveledRWMutex) AssertLocked(lockState *LockState) {
   391  	et := lockState.getExclusionType(rw.level)
   392  	if et != writeExclusion {
   393  		panic(unexpectedExclusionTypeError{
   394  			levelToString:         lockState.levelToString,
   395  			level:                 rw.level,
   396  			expectedExclusionType: writeExclusion,
   397  			exclusionType:         et,
   398  		})
   399  	}
   400  }
   401  
   402  // AssertRLocked does nothing if m is r-locked with respect to the
   403  // given LockState. Otherwise, it panics.
   404  func (rw LeveledRWMutex) AssertRLocked(lockState *LockState) {
   405  	et := lockState.getExclusionType(rw.level)
   406  	if et != readExclusion {
   407  		panic(unexpectedExclusionTypeError{
   408  			levelToString:         lockState.levelToString,
   409  			level:                 rw.level,
   410  			expectedExclusionType: readExclusion,
   411  			exclusionType:         et,
   412  		})
   413  	}
   414  }
   415  
   416  type unexpectedNonExclusionError struct {
   417  	levelToString func(MutexLevel) string
   418  	level         MutexLevel
   419  }
   420  
   421  func (e unexpectedNonExclusionError) Error() string {
   422  	return fmt.Sprintf("%s unexpectedly unlocked", e.levelToString(e.level))
   423  }
   424  
   425  // AssertAnyLocked does nothing if m is locked or r-locked with
   426  // respect to the given LockState. Otherwise, it panics.
   427  func (rw LeveledRWMutex) AssertAnyLocked(lockState *LockState) {
   428  	et := lockState.getExclusionType(rw.level)
   429  	if et == nonExclusion {
   430  		panic(unexpectedNonExclusionError{
   431  			levelToString: lockState.levelToString,
   432  			level:         rw.level,
   433  		})
   434  	}
   435  }
   436  
   437  // RLocker implements the RWMutex interface for LeveledRMMutex.
   438  func (rw LeveledRWMutex) RLocker() LeveledLocker {
   439  	return (leveledRLocker)(rw)
   440  }
   441  
   442  type leveledRLocker LeveledRWMutex
   443  
   444  func (r leveledRLocker) Lock(lockState *LockState) {
   445  	(LeveledRWMutex)(r).RLock(lockState)
   446  }
   447  
   448  func (r leveledRLocker) Unlock(lockState *LockState) {
   449  	(LeveledRWMutex)(r).RUnlock(lockState)
   450  }