github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/cmn/cos/sync.go (about) 1 // Package cos provides common low-level types and utilities for all aistore projects 2 /* 3 * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved. 4 */ 5 package cos 6 7 import ( 8 "fmt" 9 "sync" 10 "time" 11 12 "github.com/NVIDIA/aistore/cmn/atomic" 13 "github.com/NVIDIA/aistore/cmn/debug" 14 ) 15 16 const ( 17 // Number of sync maps 18 MultiSyncMapCount = 0x40 // m.b. a power of two 19 MultiSyncMapMask = MultiSyncMapCount - 1 20 ) 21 22 type ( 23 // TimeoutGroup is similar to sync.WaitGroup with the difference on Wait 24 // where we only allow timing out. 25 // 26 // WARNING: It should not be used in critical code as it may have worse 27 // performance than sync.WaitGroup - use only if its needed. 28 // 29 // WARNING: It is not safe to wait on completion in multiple threads! 30 // 31 // WARNING: It is not recommended to reuse the TimeoutGroup - it was not 32 // designed for that and bugs can be expected, especially when previous 33 // group was not called with successful (without timeout) WaitTimeout. 34 TimeoutGroup struct { 35 fin chan struct{} 36 pending atomic.Int32 37 postedFin atomic.Int32 38 } 39 40 // StopCh is a channel for stopping running things. 41 StopCh struct { 42 ch chan struct{} 43 stopped atomic.Bool 44 } 45 46 // Semaphore is a textbook _sempahore_ implemented as a wrapper on `chan struct{}`. 47 Semaphore struct { 48 s chan struct{} 49 } 50 51 // DynSemaphore implements sempahore which can change its size during usage. 52 DynSemaphore struct { 53 c *sync.Cond 54 size int 55 cur int 56 mu sync.Mutex 57 } 58 59 // WG is an interface for wait group 60 WG interface { 61 Add(int) 62 Done() 63 Wait() 64 } 65 66 // LimitedWaitGroup is helper struct which combines standard wait group and 67 // semaphore to limit the number of goroutines created. 68 LimitedWaitGroup struct { 69 wg *sync.WaitGroup 70 sema *DynSemaphore 71 } 72 73 MultiSyncMap struct { 74 M [MultiSyncMapCount]sync.Map 75 } 76 77 NopLocker struct{} 78 ) 79 80 // interface guard 81 var ( 82 _ WG = (*LimitedWaitGroup)(nil) 83 _ WG = (*TimeoutGroup)(nil) 84 ) 85 86 /////////////// 87 // NopLocker // 88 /////////////// 89 90 func (NopLocker) Lock() {} 91 func (NopLocker) Unlock() {} 92 93 ////////////////// 94 // TimeoutGroup // 95 ////////////////// 96 97 func NewTimeoutGroup() *TimeoutGroup { 98 return &TimeoutGroup{ 99 fin: make(chan struct{}, 1), 100 } 101 } 102 103 func (twg *TimeoutGroup) Add(n int) { 104 twg.pending.Add(int32(n)) 105 } 106 107 // Wait waits until the Added pending count goes to zero. 108 // NOTE: must be invoked after _all_ Adds. 109 func (twg *TimeoutGroup) Wait() { 110 twg.WaitTimeoutWithStop(24*time.Hour, nil) 111 } 112 113 // Wait waits until the Added pending count goes to zero _or_ timeout. 114 // NOTE: must be invoked after _all_ Adds. 115 func (twg *TimeoutGroup) WaitTimeout(timeout time.Duration) bool { 116 timed, _ := twg.WaitTimeoutWithStop(timeout, nil) 117 return timed 118 } 119 120 // Wait waits until the Added pending count goes to zero _or_ timeout _or_ stop. 121 // NOTE: must be invoked after _all_ Adds. 122 func (twg *TimeoutGroup) WaitTimeoutWithStop(timeout time.Duration, stop <-chan struct{}) (timed, stopped bool) { 123 t := time.NewTimer(timeout) 124 select { 125 case <-twg.fin: 126 twg.postedFin.Store(0) 127 case <-t.C: 128 timed, stopped = true, false 129 case <-stop: 130 timed, stopped = false, true 131 } 132 t.Stop() 133 return 134 } 135 136 // Done decrements number of jobs left to do. Panics if the number jobs left is 137 // less than 0. 138 func (twg *TimeoutGroup) Done() { 139 if n := twg.pending.Dec(); n == 0 { 140 if posted := twg.postedFin.Swap(1); posted == 0 { 141 twg.fin <- struct{}{} 142 } 143 } else if n < 0 { 144 AssertMsg(false, fmt.Sprintf("invalid num pending %d", n)) 145 } 146 } 147 148 //////////// 149 // StopCh // 150 //////////// 151 152 func NewStopCh() *StopCh { 153 return &StopCh{ch: make(chan struct{}, 1)} 154 } 155 156 func (sch *StopCh) Init() { 157 debug.Assert(sch.ch == nil && !sch.stopped.Load()) 158 sch.ch = make(chan struct{}, 1) 159 } 160 161 func (sch *StopCh) Listen() <-chan struct{} { 162 return sch.ch 163 } 164 165 func (sch *StopCh) Close() { 166 if sch.stopped.CAS(false, true) { 167 close(sch.ch) 168 } 169 } 170 171 /////////////// 172 // Semaphore // 173 /////////////// 174 175 func NewSemaphore(n int) *Semaphore { 176 s := &Semaphore{s: make(chan struct{}, n)} 177 for range n { 178 s.s <- struct{}{} 179 } 180 return s 181 } 182 func (s *Semaphore) TryAcquire() <-chan struct{} { return s.s } 183 func (s *Semaphore) Acquire() { <-s.TryAcquire() } 184 func (s *Semaphore) Release() { s.s <- struct{}{} } 185 186 func NewDynSemaphore(n int) *DynSemaphore { 187 sema := &DynSemaphore{size: n} 188 sema.c = sync.NewCond(&sema.mu) 189 return sema 190 } 191 192 ////////////////// 193 // DynSemaphore // 194 ////////////////// 195 196 func (s *DynSemaphore) Size() int { 197 s.mu.Lock() 198 size := s.size 199 s.mu.Unlock() 200 return size 201 } 202 203 func (s *DynSemaphore) SetSize(n int) { 204 Assert(n >= 1) 205 s.mu.Lock() 206 s.size = n 207 s.mu.Unlock() 208 } 209 210 func (s *DynSemaphore) Acquire(cnts ...int) { 211 cnt := 1 212 if len(cnts) > 0 { 213 cnt = cnts[0] 214 } 215 s.mu.Lock() 216 check: 217 if s.cur+cnt <= s.size { 218 s.cur += cnt 219 s.mu.Unlock() 220 return 221 } 222 223 // Wait for vacant place(s) 224 s.c.Wait() 225 goto check 226 } 227 228 func (s *DynSemaphore) Release(cnts ...int) { 229 cnt := 1 230 if len(cnts) > 0 { 231 cnt = cnts[0] 232 } 233 234 s.mu.Lock() 235 236 Assert(s.cur >= cnt) 237 238 s.cur -= cnt 239 s.c.Broadcast() 240 s.mu.Unlock() 241 } 242 243 ////////////////////// 244 // LimitedWaitGroup // 245 ////////////////////// 246 247 // usage: no more than `limit` (e.g., sys.NumCPU()) goroutines in parallel 248 func NewLimitedWaitGroup(limit, wanted int) WG { 249 debug.Assert(limit > 0 || wanted > 0) 250 if wanted == 0 || wanted > limit { 251 return &LimitedWaitGroup{wg: &sync.WaitGroup{}, sema: NewDynSemaphore(limit)} 252 } 253 return &sync.WaitGroup{} 254 } 255 256 func (lwg *LimitedWaitGroup) Add(n int) { 257 lwg.sema.Acquire(n) 258 lwg.wg.Add(n) 259 } 260 261 func (lwg *LimitedWaitGroup) Done() { 262 lwg.sema.Release() 263 lwg.wg.Done() 264 } 265 266 func (lwg *LimitedWaitGroup) Wait() { 267 lwg.wg.Wait() 268 } 269 270 ////////////////// 271 // MultiSyncMap // 272 ////////////////// 273 274 func (msm *MultiSyncMap) Get(idx int) *sync.Map { 275 Assert(idx >= 0 && idx < MultiSyncMapCount) 276 return &msm.M[idx] 277 } 278 279 func (msm *MultiSyncMap) GetByHash(hash uint32) *sync.Map { 280 return &msm.M[hash%MultiSyncMapCount] 281 }