github.com/andy2046/gopie@v0.7.0/pkg/barrier/barrier.go (about)

     1  // Package barrier provides a barrier implementation.
     2  package barrier
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"sync"
     8  )
     9  
    10  // Barrier is a synchronizer that allows members to wait for each other.
    11  type Barrier struct {
    12  	count      int
    13  	n          int
    14  	isBroken   bool
    15  	waitChan   chan struct{}
    16  	fallChan   chan struct{}
    17  	brokenChan chan struct{}
    18  	mu         sync.RWMutex
    19  }
    20  
    21  var (
    22  	// ErrBroken when barrier is broken.
    23  	ErrBroken = errors.New("barrier is broken")
    24  )
    25  
    26  // New returns a new barrier.
    27  func New(n int) *Barrier {
    28  	if n <= 0 {
    29  		panic("number of members must be positive int")
    30  	}
    31  	b := &Barrier{
    32  		n:          n,
    33  		waitChan:   make(chan struct{}),
    34  		brokenChan: make(chan struct{}),
    35  	}
    36  	return b
    37  }
    38  
    39  // Await waits until all members have called await on the barrier.
    40  func (b *Barrier) Await(ctx context.Context) error {
    41  	select {
    42  	case <-ctx.Done():
    43  		return ctx.Err()
    44  	default:
    45  	}
    46  
    47  	b.mu.Lock()
    48  
    49  	if b.isBroken {
    50  		b.mu.Unlock()
    51  		return ErrBroken
    52  	}
    53  
    54  	b.count++
    55  	waitChan := b.waitChan
    56  	brokenChan := b.brokenChan
    57  	count := b.count
    58  
    59  	if count == b.n {
    60  		b.reset(true)
    61  		b.mu.Unlock()
    62  		return nil
    63  	}
    64  
    65  	b.mu.Unlock()
    66  
    67  	select {
    68  	case <-waitChan:
    69  		return nil
    70  	case <-brokenChan:
    71  		return ErrBroken
    72  	case <-ctx.Done():
    73  		b.broke(true)
    74  		return ctx.Err()
    75  	}
    76  }
    77  
    78  func (b *Barrier) broke(toLock bool) {
    79  	if toLock {
    80  		b.mu.Lock()
    81  		defer b.mu.Unlock()
    82  	}
    83  
    84  	if !b.isBroken {
    85  		b.isBroken = true
    86  		close(b.brokenChan)
    87  	}
    88  }
    89  
    90  func (b *Barrier) reset(ok bool) {
    91  	if ok {
    92  		close(b.waitChan)
    93  	} else if b.count > 0 {
    94  		b.broke(false)
    95  	}
    96  
    97  	b.waitChan = make(chan struct{})
    98  	b.brokenChan = make(chan struct{})
    99  	b.count = 0
   100  	b.isBroken = false
   101  }
   102  
   103  // Reset resets the barrier to initial state.
   104  func (b *Barrier) Reset() {
   105  	b.mu.Lock()
   106  	defer b.mu.Unlock()
   107  	b.reset(false)
   108  }
   109  
   110  // N returns the number of members for the barrier.
   111  func (b *Barrier) N() int {
   112  	return b.n
   113  }
   114  
   115  // NWaiting returns the number of members currently waiting at the barrier.
   116  func (b *Barrier) NWaiting() int {
   117  	b.mu.RLock()
   118  	defer b.mu.RUnlock()
   119  	return b.count
   120  }
   121  
   122  // IsBroken returns true if the barrier is broken.
   123  func (b *Barrier) IsBroken() bool {
   124  	b.mu.RLock()
   125  	defer b.mu.RUnlock()
   126  	return b.isBroken
   127  }