github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfssync/repeated_wait_group.go (about)

     1  // Copyright 2016 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package kbfssync
     6  
     7  import (
     8  	"sync"
     9  
    10  	"golang.org/x/net/context"
    11  )
    12  
    13  // RepeatedWaitGroup can be used in place of a sync.WaitGroup when
    14  // code may need to repeatedly wait for a set of tasks to finish.
    15  // (sync.WaitGroup requires special mutex usage to make this work
    16  // properly, which can easily lead to deadlocks.)  We use a mutex,
    17  // int, and channel to track and synchronize on the number of
    18  // outstanding tasks.
    19  type RepeatedWaitGroup struct {
    20  	lock     sync.Mutex
    21  	num      int
    22  	isIdleCh chan struct{} // leave as nil when initializing
    23  	// TODO: we could remove this paused bool by converting the
    24  	// `pauseCh` into an `onPauseCh` that starts off initialized and
    25  	// gets set to nil when a pause happens.  But that would require
    26  	// an initializer for the channel.
    27  	paused  bool
    28  	pauseCh chan struct{} // leave as nil when initializing
    29  }
    30  
    31  // Add indicates that a number of tasks have begun.
    32  func (rwg *RepeatedWaitGroup) Add(delta int) {
    33  	rwg.lock.Lock()
    34  	defer rwg.lock.Unlock()
    35  	if rwg.isIdleCh == nil {
    36  		rwg.isIdleCh = make(chan struct{})
    37  	}
    38  	if rwg.num+delta < 0 {
    39  		panic("RepeatedWaitGroup count would be negative")
    40  	}
    41  	rwg.num += delta
    42  	if rwg.num == 0 {
    43  		close(rwg.isIdleCh)
    44  		rwg.isIdleCh = nil
    45  	}
    46  }
    47  
    48  // Wait blocks until either the underlying task count goes to 0, or
    49  // the given context is canceled.
    50  func (rwg *RepeatedWaitGroup) Wait(ctx context.Context) error {
    51  	isIdleCh := func() chan struct{} {
    52  		rwg.lock.Lock()
    53  		defer rwg.lock.Unlock()
    54  		return rwg.isIdleCh
    55  	}()
    56  
    57  	if isIdleCh == nil {
    58  		return nil
    59  	}
    60  
    61  	select {
    62  	case <-isIdleCh:
    63  		return nil
    64  	case <-ctx.Done():
    65  		return ctx.Err()
    66  	}
    67  }
    68  
    69  // WaitUnlessPaused works like Wait, except it can return early if the
    70  // wait group is paused.  It returns whether it was paused with
    71  // outstanding work still left in the group.
    72  func (rwg *RepeatedWaitGroup) WaitUnlessPaused(ctx context.Context) (
    73  	bool, error) {
    74  	paused, isIdleCh, pauseCh := func() (bool, chan struct{}, chan struct{}) {
    75  		rwg.lock.Lock()
    76  		defer rwg.lock.Unlock()
    77  		if !rwg.paused && rwg.pauseCh == nil {
    78  			rwg.pauseCh = make(chan struct{})
    79  		}
    80  		return rwg.paused, rwg.isIdleCh, rwg.pauseCh
    81  	}()
    82  
    83  	if isIdleCh == nil {
    84  		return false, nil
    85  	}
    86  
    87  	if paused {
    88  		return true, nil
    89  	}
    90  
    91  	select {
    92  	case <-isIdleCh:
    93  		return false, nil
    94  	case <-pauseCh:
    95  		return true, nil
    96  	case <-ctx.Done():
    97  		return false, ctx.Err()
    98  	}
    99  }
   100  
   101  // Pause causes any current or future callers of `WaitUnlessPaused` to
   102  // return immediately.
   103  func (rwg *RepeatedWaitGroup) Pause() {
   104  	rwg.lock.Lock()
   105  	defer rwg.lock.Unlock()
   106  	rwg.paused = true
   107  	if rwg.pauseCh != nil {
   108  		close(rwg.pauseCh)
   109  		rwg.pauseCh = nil
   110  	}
   111  }
   112  
   113  // Resume unpauses the wait group, allowing future callers of
   114  // `WaitUnlessPaused` to wait until all the outstanding work is
   115  // completed.
   116  func (rwg *RepeatedWaitGroup) Resume() {
   117  	rwg.lock.Lock()
   118  	defer rwg.lock.Unlock()
   119  	if rwg.pauseCh != nil {
   120  		panic("Non-nil pauseCh on resume!")
   121  	}
   122  	rwg.paused = false
   123  }
   124  
   125  // Done indicates that one task has completed.
   126  func (rwg *RepeatedWaitGroup) Done() {
   127  	rwg.Add(-1)
   128  }