github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/common/sharded_locks.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package common
    13  
    14  import "sync"
    15  
// DefaultShardedLocksCount is the shard count that NewDefaultShardedLocks and
// NewDefaultShardedRWLocks pass to their respective constructors.
const DefaultShardedLocksCount = 512
    17  
    18  type ShardedLocks struct {
    19  	// sharded locks
    20  	shards []sync.Mutex
    21  	// number of locks
    22  	count uint64
    23  }
    24  
    25  func NewDefaultShardedLocks() *ShardedLocks {
    26  	return NewShardedLocks(DefaultShardedLocksCount)
    27  }
    28  
    29  func NewShardedLocks(count uint64) *ShardedLocks {
    30  	if count < 2 {
    31  		count = 2
    32  	}
    33  
    34  	return &ShardedLocks{
    35  		shards: make([]sync.Mutex, count),
    36  		count:  count,
    37  	}
    38  }
    39  
    40  func (sl *ShardedLocks) LockAll() {
    41  	for i := uint64(0); i < sl.count; i++ {
    42  		sl.shards[i].Lock()
    43  	}
    44  }
    45  
    46  func (sl *ShardedLocks) UnlockAll() {
    47  	for i := int(sl.count) - 1; i >= 0; i-- {
    48  		sl.shards[i].Unlock()
    49  	}
    50  }
    51  
    52  func (sl *ShardedLocks) LockedAll(callback func()) {
    53  	sl.LockAll()
    54  	defer sl.UnlockAll()
    55  
    56  	callback()
    57  }
    58  
    59  func (sl *ShardedLocks) Lock(id uint64) {
    60  	sl.shards[id%sl.count].Lock()
    61  }
    62  
    63  func (sl *ShardedLocks) Unlock(id uint64) {
    64  	sl.shards[id%sl.count].Unlock()
    65  }
    66  
    67  func (sl *ShardedLocks) Locked(id uint64, callback func()) {
    68  	sl.Lock(id)
    69  	defer sl.Unlock(id)
    70  
    71  	callback()
    72  }
    73  
    74  type ShardedRWLocks struct {
    75  	// sharded locks
    76  	shards []sync.RWMutex
    77  	// number of locks
    78  	count uint64
    79  }
    80  
    81  func NewDefaultShardedRWLocks() *ShardedRWLocks {
    82  	return NewShardedRWLocks(DefaultShardedLocksCount)
    83  }
    84  
    85  func NewShardedRWLocks(count uint64) *ShardedRWLocks {
    86  	if count < 2 {
    87  		count = 2
    88  	}
    89  
    90  	return &ShardedRWLocks{
    91  		shards: make([]sync.RWMutex, count),
    92  		count:  count,
    93  	}
    94  }
    95  
    96  func (sl *ShardedRWLocks) LockAll() {
    97  	for i := uint64(0); i < sl.count; i++ {
    98  		sl.shards[i].Lock()
    99  	}
   100  }
   101  
   102  func (sl *ShardedRWLocks) UnlockAll() {
   103  	for i := int(sl.count) - 1; i >= 0; i-- {
   104  		sl.shards[i].Unlock()
   105  	}
   106  }
   107  
   108  func (sl *ShardedRWLocks) LockedAll(callback func()) {
   109  	sl.LockAll()
   110  	defer sl.UnlockAll()
   111  
   112  	callback()
   113  }
   114  
   115  func (sl *ShardedRWLocks) Lock(id uint64) {
   116  	sl.shards[id%sl.count].Lock()
   117  }
   118  
   119  func (sl *ShardedRWLocks) Unlock(id uint64) {
   120  	sl.shards[id%sl.count].Unlock()
   121  }
   122  
   123  func (sl *ShardedRWLocks) Locked(id uint64, callback func()) {
   124  	sl.Lock(id)
   125  	defer sl.Unlock(id)
   126  
   127  	callback()
   128  }
   129  
   130  func (sl *ShardedRWLocks) RLockAll() {
   131  	for i := uint64(0); i < sl.count; i++ {
   132  		sl.shards[i].RLock()
   133  	}
   134  }
   135  
   136  func (sl *ShardedRWLocks) RUnlockAll() {
   137  	for i := int(sl.count) - 1; i >= 0; i-- {
   138  		sl.shards[i].RUnlock()
   139  	}
   140  }
   141  
   142  func (sl *ShardedRWLocks) RLockedAll(callback func()) {
   143  	sl.RLockAll()
   144  	defer sl.RUnlockAll()
   145  
   146  	callback()
   147  }
   148  
   149  func (sl *ShardedRWLocks) RLock(id uint64) {
   150  	sl.shards[id%sl.count].RLock()
   151  }
   152  
   153  func (sl *ShardedRWLocks) RUnlock(id uint64) {
   154  	sl.shards[id%sl.count].RUnlock()
   155  }
   156  
   157  func (sl *ShardedRWLocks) RLocked(id uint64, callback func()) {
   158  	sl.RLock(id)
   159  	defer sl.RUnlock(id)
   160  
   161  	callback()
   162  }