github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/cmn/cos/sync.go (about)

     1  // Package cos provides common low-level types and utilities for all aistore projects
     2  /*
     3   * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package cos
     6  
     7  import (
     8  	"fmt"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/NVIDIA/aistore/cmn/atomic"
    13  	"github.com/NVIDIA/aistore/cmn/debug"
    14  )
    15  
    16  const (
    17  	// Number of sync maps
    18  	MultiSyncMapCount = 0x40 // m.b. a power of two
    19  	MultiSyncMapMask  = MultiSyncMapCount - 1
    20  )
    21  
    22  type (
    23  	// TimeoutGroup is similar to sync.WaitGroup with the difference on Wait
    24  	// where we only allow timing out.
    25  	//
    26  	// WARNING: It should not be used in critical code as it may have worse
    27  	// performance than sync.WaitGroup - use only if its needed.
    28  	//
    29  	// WARNING: It is not safe to wait on completion in multiple threads!
    30  	//
    31  	// WARNING: It is not recommended to reuse the TimeoutGroup - it was not
    32  	// designed for that and bugs can be expected, especially when previous
    33  	// group was not called with successful (without timeout) WaitTimeout.
    34  	TimeoutGroup struct {
    35  		fin       chan struct{}
    36  		pending   atomic.Int32
    37  		postedFin atomic.Int32
    38  	}
    39  
    40  	// StopCh is a channel for stopping running things.
    41  	StopCh struct {
    42  		ch      chan struct{}
    43  		stopped atomic.Bool
    44  	}
    45  
    46  	// Semaphore is a textbook _sempahore_ implemented as a wrapper on `chan struct{}`.
    47  	Semaphore struct {
    48  		s chan struct{}
    49  	}
    50  
    51  	// DynSemaphore implements sempahore which can change its size during usage.
    52  	DynSemaphore struct {
    53  		c    *sync.Cond
    54  		size int
    55  		cur  int
    56  		mu   sync.Mutex
    57  	}
    58  
    59  	// WG is an interface for wait group
    60  	WG interface {
    61  		Add(int)
    62  		Done()
    63  		Wait()
    64  	}
    65  
    66  	// LimitedWaitGroup is helper struct which combines standard wait group and
    67  	// semaphore to limit the number of goroutines created.
    68  	LimitedWaitGroup struct {
    69  		wg   *sync.WaitGroup
    70  		sema *DynSemaphore
    71  	}
    72  
    73  	MultiSyncMap struct {
    74  		M [MultiSyncMapCount]sync.Map
    75  	}
    76  
    77  	NopLocker struct{}
    78  )
    79  
    80  // interface guard
    81  var (
    82  	_ WG = (*LimitedWaitGroup)(nil)
    83  	_ WG = (*TimeoutGroup)(nil)
    84  )
    85  
    86  ///////////////
    87  // NopLocker //
    88  ///////////////
    89  
    90  func (NopLocker) Lock()   {}
    91  func (NopLocker) Unlock() {}
    92  
    93  //////////////////
    94  // TimeoutGroup //
    95  //////////////////
    96  
    97  func NewTimeoutGroup() *TimeoutGroup {
    98  	return &TimeoutGroup{
    99  		fin: make(chan struct{}, 1),
   100  	}
   101  }
   102  
   103  func (twg *TimeoutGroup) Add(n int) {
   104  	twg.pending.Add(int32(n))
   105  }
   106  
   107  // Wait waits until the Added pending count goes to zero.
   108  // NOTE: must be invoked after _all_ Adds.
   109  func (twg *TimeoutGroup) Wait() {
   110  	twg.WaitTimeoutWithStop(24*time.Hour, nil)
   111  }
   112  
   113  // Wait waits until the Added pending count goes to zero _or_ timeout.
   114  // NOTE: must be invoked after _all_ Adds.
   115  func (twg *TimeoutGroup) WaitTimeout(timeout time.Duration) bool {
   116  	timed, _ := twg.WaitTimeoutWithStop(timeout, nil)
   117  	return timed
   118  }
   119  
   120  // Wait waits until the Added pending count goes to zero _or_ timeout _or_ stop.
   121  // NOTE: must be invoked after _all_ Adds.
   122  func (twg *TimeoutGroup) WaitTimeoutWithStop(timeout time.Duration, stop <-chan struct{}) (timed, stopped bool) {
   123  	t := time.NewTimer(timeout)
   124  	select {
   125  	case <-twg.fin:
   126  		twg.postedFin.Store(0)
   127  	case <-t.C:
   128  		timed, stopped = true, false
   129  	case <-stop:
   130  		timed, stopped = false, true
   131  	}
   132  	t.Stop()
   133  	return
   134  }
   135  
   136  // Done decrements number of jobs left to do. Panics if the number jobs left is
   137  // less than 0.
   138  func (twg *TimeoutGroup) Done() {
   139  	if n := twg.pending.Dec(); n == 0 {
   140  		if posted := twg.postedFin.Swap(1); posted == 0 {
   141  			twg.fin <- struct{}{}
   142  		}
   143  	} else if n < 0 {
   144  		AssertMsg(false, fmt.Sprintf("invalid num pending %d", n))
   145  	}
   146  }
   147  
   148  ////////////
   149  // StopCh //
   150  ////////////
   151  
   152  func NewStopCh() *StopCh {
   153  	return &StopCh{ch: make(chan struct{}, 1)}
   154  }
   155  
   156  func (sch *StopCh) Init() {
   157  	debug.Assert(sch.ch == nil && !sch.stopped.Load())
   158  	sch.ch = make(chan struct{}, 1)
   159  }
   160  
   161  func (sch *StopCh) Listen() <-chan struct{} {
   162  	return sch.ch
   163  }
   164  
   165  func (sch *StopCh) Close() {
   166  	if sch.stopped.CAS(false, true) {
   167  		close(sch.ch)
   168  	}
   169  }
   170  
   171  ///////////////
   172  // Semaphore //
   173  ///////////////
   174  
   175  func NewSemaphore(n int) *Semaphore {
   176  	s := &Semaphore{s: make(chan struct{}, n)}
   177  	for range n {
   178  		s.s <- struct{}{}
   179  	}
   180  	return s
   181  }
   182  func (s *Semaphore) TryAcquire() <-chan struct{} { return s.s }
   183  func (s *Semaphore) Acquire()                    { <-s.TryAcquire() }
   184  func (s *Semaphore) Release()                    { s.s <- struct{}{} }
   185  
   186  func NewDynSemaphore(n int) *DynSemaphore {
   187  	sema := &DynSemaphore{size: n}
   188  	sema.c = sync.NewCond(&sema.mu)
   189  	return sema
   190  }
   191  
   192  //////////////////
   193  // DynSemaphore //
   194  //////////////////
   195  
   196  func (s *DynSemaphore) Size() int {
   197  	s.mu.Lock()
   198  	size := s.size
   199  	s.mu.Unlock()
   200  	return size
   201  }
   202  
   203  func (s *DynSemaphore) SetSize(n int) {
   204  	Assert(n >= 1)
   205  	s.mu.Lock()
   206  	s.size = n
   207  	s.mu.Unlock()
   208  }
   209  
   210  func (s *DynSemaphore) Acquire(cnts ...int) {
   211  	cnt := 1
   212  	if len(cnts) > 0 {
   213  		cnt = cnts[0]
   214  	}
   215  	s.mu.Lock()
   216  check:
   217  	if s.cur+cnt <= s.size {
   218  		s.cur += cnt
   219  		s.mu.Unlock()
   220  		return
   221  	}
   222  
   223  	// Wait for vacant place(s)
   224  	s.c.Wait()
   225  	goto check
   226  }
   227  
   228  func (s *DynSemaphore) Release(cnts ...int) {
   229  	cnt := 1
   230  	if len(cnts) > 0 {
   231  		cnt = cnts[0]
   232  	}
   233  
   234  	s.mu.Lock()
   235  
   236  	Assert(s.cur >= cnt)
   237  
   238  	s.cur -= cnt
   239  	s.c.Broadcast()
   240  	s.mu.Unlock()
   241  }
   242  
   243  //////////////////////
   244  // LimitedWaitGroup //
   245  //////////////////////
   246  
   247  // usage: no more than `limit` (e.g., sys.NumCPU()) goroutines in parallel
   248  func NewLimitedWaitGroup(limit, wanted int) WG {
   249  	debug.Assert(limit > 0 || wanted > 0)
   250  	if wanted == 0 || wanted > limit {
   251  		return &LimitedWaitGroup{wg: &sync.WaitGroup{}, sema: NewDynSemaphore(limit)}
   252  	}
   253  	return &sync.WaitGroup{}
   254  }
   255  
   256  func (lwg *LimitedWaitGroup) Add(n int) {
   257  	lwg.sema.Acquire(n)
   258  	lwg.wg.Add(n)
   259  }
   260  
   261  func (lwg *LimitedWaitGroup) Done() {
   262  	lwg.sema.Release()
   263  	lwg.wg.Done()
   264  }
   265  
   266  func (lwg *LimitedWaitGroup) Wait() {
   267  	lwg.wg.Wait()
   268  }
   269  
   270  //////////////////
   271  // MultiSyncMap //
   272  //////////////////
   273  
   274  func (msm *MultiSyncMap) Get(idx int) *sync.Map {
   275  	Assert(idx >= 0 && idx < MultiSyncMapCount)
   276  	return &msm.M[idx]
   277  }
   278  
   279  func (msm *MultiSyncMap) GetByHash(hash uint32) *sync.Map {
   280  	return &msm.M[hash%MultiSyncMapCount]
   281  }