github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfssync/leveled_mutex.go (about) 1 // Copyright 2016 Keybase Inc. All rights reserved. 2 // Use of this source code is governed by a BSD 3 // license that can be found in the LICENSE file. 4 5 package kbfssync 6 7 import ( 8 "fmt" 9 "sync" 10 "sync/atomic" 11 ) 12 13 // The LeveledMutex, LeveledRWMutex, and LockState types enables a 14 // lock hierarchy to be checked. For a program (or subsystem), each 15 // (rw-)mutex must have a unique associated MutexLevel, which means 16 // that a (rw-)mutex must not be (r-)locked before another (rw-)mutex 17 // with a lower MutexLevel in a given execution flow. This is achieved 18 // by creating a new LockState at the start of an execution flow and 19 // passing it to the (r-)lock/(r-)unlock methods of each (rw-)mutex. 20 // 21 // TODO: Once this becomes a bottleneck, add a +build production 22 // version that stubs everything out. 23 24 // An exclusiveLock is a lock around something that is expected to be 25 // accessed exclusively. It immediately panics upon any lock 26 // contention. 27 type exclusiveLock struct { 28 v *int32 29 } 30 31 func makeExclusiveLock() exclusiveLock { 32 return exclusiveLock{ 33 v: new(int32), 34 } 35 } 36 37 func (l exclusiveLock) lock() { 38 if !atomic.CompareAndSwapInt32(l.v, 0, 1) { 39 panic("unexpected concurrent access") 40 } 41 } 42 43 func (l exclusiveLock) unlock() { 44 if !atomic.CompareAndSwapInt32(l.v, 1, 0) { 45 panic("unexpected concurrent access") 46 } 47 } 48 49 // MutexLevel is the level for a mutex, which must be unique to that 50 // mutex. 51 type MutexLevel int 52 53 // exclusionType is the type of exclusion of a lock. A regular lock 54 // always uses write exclusion, where only one thing at a time can 55 // hold the lock, whereas a reader-writer lock can do either write 56 // exclusion or read exclusion, where only one writer or any number of 57 // readers can hold the lock. 58 type exclusionType int 59 60 const ( 61 nonExclusion exclusionType = 0 62 writeExclusion exclusionType = 1 63 readExclusion exclusionType = 2 64 ) 65 66 func (et exclusionType) prefix() string { 67 switch et { 68 case nonExclusion: 69 return "Un" 70 case writeExclusion: 71 return "" 72 case readExclusion: 73 return "R" 74 } 75 return fmt.Sprintf("exclusionType{%d}", et) 76 } 77 78 // exclusionState holds the state for a held mutex. 79 type exclusionState struct { 80 // The level of the held mutex. 81 level MutexLevel 82 // The exclusion type of the held mutex. 83 exclusionType exclusionType 84 } 85 86 // LockState holds the info regarding which level mutexes are held or 87 // not for a particular execution flow. 88 type LockState struct { 89 levelToString func(MutexLevel) string 90 91 // Protects exclusionStates. 92 exclusionStatesLock exclusiveLock 93 // The stack of held mutexes, ordered by increasing level. 94 exclusionStates []exclusionState 95 } 96 97 // MakeLevelState returns a new LockState. This must be called at the 98 // start of a new execution flow and passed to any LeveledMutex or 99 // LeveledRWMutex operation during that execution flow. 100 // 101 // TODO: Consider adding a parameter to set the capacity of 102 // exclusionStates. 103 func MakeLevelState(levelToString func(MutexLevel) string) *LockState { 104 return &LockState{ 105 levelToString: levelToString, 106 exclusionStatesLock: makeExclusiveLock(), 107 } 108 } 109 110 // currLocked returns the current exclusion state, or nil if there is 111 // none. 112 func (state *LockState) currLocked() *exclusionState { 113 stateCount := len(state.exclusionStates) 114 if stateCount == 0 { 115 return nil 116 } 117 return &state.exclusionStates[stateCount-1] 118 } 119 120 type levelViolationError struct { 121 levelToString func(MutexLevel) string 122 level MutexLevel 123 exclusionType exclusionType 124 curr exclusionState 125 } 126 127 func (e levelViolationError) Error() string { 128 return fmt.Sprintf("level violation: %s %sLocked after %s %sLocked", 129 e.levelToString(e.level), e.exclusionType.prefix(), 130 e.levelToString(e.curr.level), e.curr.exclusionType.prefix()) 131 } 132 133 func (state *LockState) doLock( 134 level MutexLevel, exclusionType exclusionType, lock sync.Locker) error { 135 state.exclusionStatesLock.lock() 136 defer state.exclusionStatesLock.unlock() 137 138 curr := state.currLocked() 139 140 if curr != nil && level <= curr.level { 141 return levelViolationError{ 142 levelToString: state.levelToString, 143 level: level, 144 exclusionType: exclusionType, 145 curr: *curr, 146 } 147 } 148 149 lock.Lock() 150 151 state.exclusionStates = append(state.exclusionStates, exclusionState{ 152 level: level, 153 exclusionType: exclusionType, 154 }) 155 return nil 156 } 157 158 type danglingUnlockError struct { 159 levelToString func(MutexLevel) string 160 level MutexLevel 161 exclusionType exclusionType 162 } 163 164 func (e danglingUnlockError) Error() string { 165 return fmt.Sprintf("%s %sUnlocked while already unlocked", 166 e.levelToString(e.level), e.exclusionType.prefix()) 167 } 168 169 type mismatchedUnlockError struct { 170 levelToString func(MutexLevel) string 171 level MutexLevel 172 exclusionType exclusionType 173 curr exclusionState 174 } 175 176 func (e mismatchedUnlockError) Error() string { 177 return fmt.Sprintf( 178 "%sUnlock call for %s doesn't match %sLock call for %s", 179 e.exclusionType.prefix(), e.levelToString(e.level), 180 e.curr.exclusionType.prefix(), e.levelToString(e.curr.level)) 181 } 182 183 func (state *LockState) doUnlock( 184 level MutexLevel, exclusionType exclusionType, lock sync.Locker) error { 185 state.exclusionStatesLock.lock() 186 defer state.exclusionStatesLock.unlock() 187 188 curr := state.currLocked() 189 190 if curr == nil { 191 return danglingUnlockError{ 192 levelToString: state.levelToString, 193 level: level, 194 exclusionType: exclusionType, 195 } 196 } 197 198 if level != curr.level || curr.exclusionType != exclusionType { 199 return mismatchedUnlockError{ 200 levelToString: state.levelToString, 201 level: level, 202 exclusionType: exclusionType, 203 curr: *curr, 204 } 205 } 206 207 lock.Unlock() 208 209 state.exclusionStates = state.exclusionStates[:len(state.exclusionStates)-1] 210 return nil 211 } 212 213 // getExclusionType returns returns the exclusionType for the given 214 // MutexLevel, or nonExclusion if there is none. 215 func (state *LockState) getExclusionType(level MutexLevel) exclusionType { 216 state.exclusionStatesLock.lock() 217 defer state.exclusionStatesLock.unlock() 218 219 // Not worth it to do anything more complicated than a 220 // brute-force search. 221 for _, state := range state.exclusionStates { 222 if state.level > level { 223 break 224 } 225 if state.level == level { 226 return state.exclusionType 227 } 228 } 229 230 return nonExclusion 231 } 232 233 // LeveledMutex is a mutex with an associated level, which must be 234 // unique. Note that unlike sync.Mutex, LeveledMutex is a reference 235 // type and not a value type. 236 type LeveledMutex struct { 237 level MutexLevel 238 locker sync.Locker 239 } 240 241 // MakeLeveledMutex makes a mutex with the given level, backed by the 242 // given locker. 243 func MakeLeveledMutex(level MutexLevel, locker sync.Locker) LeveledMutex { 244 return LeveledMutex{ 245 level: level, 246 locker: locker, 247 } 248 } 249 250 // Lock locks the associated locker. 251 func (m LeveledMutex) Lock(lockState *LockState) { 252 err := lockState.doLock(m.level, writeExclusion, m.locker) 253 if err != nil { 254 panic(err) 255 } 256 } 257 258 // Unlock locks the associated locker. 259 func (m LeveledMutex) Unlock(lockState *LockState) { 260 err := lockState.doUnlock(m.level, writeExclusion, m.locker) 261 if err != nil { 262 panic(err) 263 } 264 } 265 266 type unexpectedExclusionError struct { 267 levelToString func(MutexLevel) string 268 level MutexLevel 269 exclusionType exclusionType 270 } 271 272 func (e unexpectedExclusionError) Error() string { 273 return fmt.Sprintf("%s unexpectedly %sLocked", 274 e.levelToString(e.level), e.exclusionType.prefix()) 275 } 276 277 // AssertUnlocked does nothing if m is unlocked with respect to the 278 // given LockState. Otherwise, it panics. 279 func (m LeveledMutex) AssertUnlocked(lockState *LockState) { 280 et := lockState.getExclusionType(m.level) 281 if et != nonExclusion { 282 panic(unexpectedExclusionError{ 283 levelToString: lockState.levelToString, 284 level: m.level, 285 exclusionType: et, 286 }) 287 } 288 } 289 290 type unexpectedExclusionTypeError struct { 291 levelToString func(MutexLevel) string 292 level MutexLevel 293 expectedExclusionType exclusionType 294 exclusionType exclusionType 295 } 296 297 func (e unexpectedExclusionTypeError) Error() string { 298 return fmt.Sprintf( 299 "%s unexpectedly not %sLocked; instead it is %sLocked", 300 e.levelToString(e.level), 301 e.expectedExclusionType.prefix(), 302 e.exclusionType.prefix()) 303 } 304 305 // AssertLocked does nothing if m is locked with respect to the given 306 // LockState. Otherwise, it panics. 307 func (m LeveledMutex) AssertLocked(lockState *LockState) { 308 et := lockState.getExclusionType(m.level) 309 if et != writeExclusion { 310 panic(unexpectedExclusionTypeError{ 311 levelToString: lockState.levelToString, 312 level: m.level, 313 expectedExclusionType: writeExclusion, 314 exclusionType: et, 315 }) 316 } 317 } 318 319 // LeveledLocker represents an object that can be locked and unlocked 320 // with a LockState. 321 type LeveledLocker interface { 322 Lock(*LockState) 323 Unlock(*LockState) 324 } 325 326 // LeveledRWMutex is a reader-writer mutex with an associated level, 327 // which must be unique. Note that unlike sync.RWMutex, LeveledRWMutex 328 // is a reference type and not a value type. 329 type LeveledRWMutex struct { 330 level MutexLevel 331 rwLocker rwLocker 332 } 333 334 // MakeLeveledRWMutex makes a reader-writer mutex with the given 335 // level, backed by the given rwLocker. 336 func MakeLeveledRWMutex(level MutexLevel, rwLocker rwLocker) LeveledRWMutex { 337 return LeveledRWMutex{ 338 level: level, 339 rwLocker: rwLocker, 340 } 341 } 342 343 // Lock locks the associated locker. 344 func (rw LeveledRWMutex) Lock(lockState *LockState) { 345 err := lockState.doLock(rw.level, writeExclusion, rw.rwLocker) 346 if err != nil { 347 panic(err) 348 } 349 } 350 351 // Unlock unlocks the associated locker. 352 func (rw LeveledRWMutex) Unlock(lockState *LockState) { 353 err := lockState.doUnlock(rw.level, writeExclusion, rw.rwLocker) 354 if err != nil { 355 panic(err) 356 } 357 } 358 359 // RLock locks the associated locker for reading. 360 func (rw LeveledRWMutex) RLock(lockState *LockState) { 361 err := lockState.doLock(rw.level, readExclusion, rw.rwLocker.RLocker()) 362 if err != nil { 363 panic(err) 364 } 365 } 366 367 // RUnlock unlocks the associated locker for reading. 368 func (rw LeveledRWMutex) RUnlock(lockState *LockState) { 369 err := lockState.doUnlock(rw.level, readExclusion, rw.rwLocker.RLocker()) 370 if err != nil { 371 panic(err) 372 } 373 } 374 375 // AssertUnlocked does nothing if m is unlocked with respect to the 376 // given LockState. Otherwise, it panics. 377 func (rw LeveledRWMutex) AssertUnlocked(lockState *LockState) { 378 et := lockState.getExclusionType(rw.level) 379 if et != nonExclusion { 380 panic(unexpectedExclusionError{ 381 levelToString: lockState.levelToString, 382 level: rw.level, 383 exclusionType: et, 384 }) 385 } 386 } 387 388 // AssertLocked does nothing if m is locked with respect to the given 389 // LockState. Otherwise, it panics. 390 func (rw LeveledRWMutex) AssertLocked(lockState *LockState) { 391 et := lockState.getExclusionType(rw.level) 392 if et != writeExclusion { 393 panic(unexpectedExclusionTypeError{ 394 levelToString: lockState.levelToString, 395 level: rw.level, 396 expectedExclusionType: writeExclusion, 397 exclusionType: et, 398 }) 399 } 400 } 401 402 // AssertRLocked does nothing if m is r-locked with respect to the 403 // given LockState. Otherwise, it panics. 404 func (rw LeveledRWMutex) AssertRLocked(lockState *LockState) { 405 et := lockState.getExclusionType(rw.level) 406 if et != readExclusion { 407 panic(unexpectedExclusionTypeError{ 408 levelToString: lockState.levelToString, 409 level: rw.level, 410 expectedExclusionType: readExclusion, 411 exclusionType: et, 412 }) 413 } 414 } 415 416 type unexpectedNonExclusionError struct { 417 levelToString func(MutexLevel) string 418 level MutexLevel 419 } 420 421 func (e unexpectedNonExclusionError) Error() string { 422 return fmt.Sprintf("%s unexpectedly unlocked", e.levelToString(e.level)) 423 } 424 425 // AssertAnyLocked does nothing if m is locked or r-locked with 426 // respect to the given LockState. Otherwise, it panics. 427 func (rw LeveledRWMutex) AssertAnyLocked(lockState *LockState) { 428 et := lockState.getExclusionType(rw.level) 429 if et == nonExclusion { 430 panic(unexpectedNonExclusionError{ 431 levelToString: lockState.levelToString, 432 level: rw.level, 433 }) 434 } 435 } 436 437 // RLocker implements the RWMutex interface for LeveledRMMutex. 438 func (rw LeveledRWMutex) RLocker() LeveledLocker { 439 return (leveledRLocker)(rw) 440 } 441 442 type leveledRLocker LeveledRWMutex 443 444 func (r leveledRLocker) Lock(lockState *LockState) { 445 (LeveledRWMutex)(r).RLock(lockState) 446 } 447 448 func (r leveledRLocker) Unlock(lockState *LockState) { 449 (LeveledRWMutex)(r).RUnlock(lockState) 450 }