github.com/nutsdb/nutsdb@v1.0.4/throttle.go (about)

     1  package nutsdb
     2  
     3  import (
     4  	"sync"
     5  )
     6  
     7  // Throttle allows a limited number of workers to run at a time. It also
     8  // provides a mechanism to check for errors encountered by workers and wait for
     9  // them to finish.
    10  type Throttle struct {
    11  	once      sync.Once
    12  	wg        sync.WaitGroup
    13  	ch        chan struct{}
    14  	errCh     chan error
    15  	finishErr error
    16  }
    17  
    18  // NewThrottle creates a new throttle with a max number of workers.
    19  func NewThrottle(max int) *Throttle {
    20  	return &Throttle{
    21  		ch:    make(chan struct{}, max),
    22  		errCh: make(chan error, max),
    23  	}
    24  }
    25  
    26  // Do should be called by workers before they start working. It blocks if there
    27  // are already maximum number of workers working. If it detects an error from
    28  // previously Done workers, it would return it.
    29  func (t *Throttle) Do() error {
    30  	for {
    31  		select {
    32  		case t.ch <- struct{}{}:
    33  			t.wg.Add(1)
    34  			return nil
    35  		case err := <-t.errCh:
    36  			if err != nil {
    37  				return err
    38  			}
    39  		}
    40  	}
    41  }
    42  
    43  // Done should be called by workers when they finish working. They can also
    44  // pass the error status of work done.
    45  func (t *Throttle) Done(err error) {
    46  	if err != nil {
    47  		t.errCh <- err
    48  	}
    49  	select {
    50  	case <-t.ch:
    51  	default:
    52  		panic("Throttle Do Done mismatch")
    53  	}
    54  	t.wg.Done()
    55  }
    56  
    57  // Finish waits until all workers have finished working. It would return any error passed by Done.
    58  // If Finish is called multiple time, it will wait for workers to finish only once(first time).
    59  // From next calls, it will return same error as found on first call.
    60  func (t *Throttle) Finish() error {
    61  	t.once.Do(func() {
    62  		t.wg.Wait()
    63  		close(t.ch)
    64  		close(t.errCh)
    65  		for err := range t.errCh {
    66  			if err != nil {
    67  				t.finishErr = err
    68  				return
    69  			}
    70  		}
    71  	})
    72  
    73  	return t.finishErr
    74  }