github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/wait_group.go (about)

     1  package pr2
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/reusee/e5"
     9  )
    10  
    11  type waitGroupKey struct{}
    12  
    13  var WaitGroupKey = waitGroupKey{}
    14  
    15  type WaitGroup struct {
    16  	ctx    context.Context
    17  	wg     *sync.WaitGroup
    18  	cancel context.CancelFunc
    19  }
    20  
    21  func NewWaitGroup(ctx context.Context) *WaitGroup {
    22  
    23  	if v := ctx.Value(WaitGroupKey); v != nil {
    24  		// if there is parent wait group, derive from it
    25  		parentWaitGroup := v.(*WaitGroup)
    26  		parentWaitGroup.wg.Add(1)
    27  		ctx, cancel := context.WithCancel(ctx)
    28  		wg := &WaitGroup{
    29  			wg:     new(sync.WaitGroup),
    30  			cancel: cancel,
    31  		}
    32  		ctx = context.WithValue(ctx, WaitGroupKey, wg)
    33  		wg.ctx = ctx
    34  		go func() {
    35  			<-ctx.Done()
    36  			wg.wg.Wait()
    37  			parentWaitGroup.wg.Done()
    38  		}()
    39  		return wg
    40  	}
    41  
    42  	// new root wait group
    43  	ctx, cancel := context.WithCancel(ctx)
    44  	wg := &WaitGroup{
    45  		wg:     new(sync.WaitGroup),
    46  		cancel: cancel,
    47  	}
    48  	ctx = context.WithValue(ctx, WaitGroupKey, wg)
    49  	wg.ctx = ctx
    50  	return wg
    51  }
    52  
    53  func GetWaitGroup(ctx context.Context) *WaitGroup {
    54  	if v := ctx.Value(WaitGroupKey); v != nil {
    55  		return v.(*WaitGroup)
    56  	}
    57  	return nil
    58  }
    59  
    60  func (w *WaitGroup) Cancel() {
    61  	w.cancel()
    62  }
    63  
    64  func (w *WaitGroup) Add() (done func()) {
    65  	select {
    66  	case <-w.ctx.Done():
    67  		e5.Throw(context.Canceled)
    68  	default:
    69  	}
    70  	w.wg.Add(1)
    71  	var doneOnce sync.Once
    72  	return func() {
    73  		doneOnce.Do(func() {
    74  			w.wg.Done()
    75  		})
    76  	}
    77  }
    78  
    79  func (w *WaitGroup) Wait() {
    80  	w.wg.Wait()
    81  }
    82  
    83  var _ context.Context = new(WaitGroup)
    84  
    85  func (w *WaitGroup) Done() <-chan struct{} {
    86  	return w.ctx.Done()
    87  }
    88  
    89  func (w *WaitGroup) Err() error {
    90  	return w.ctx.Err()
    91  }
    92  
    93  func (w *WaitGroup) Deadline() (deadline time.Time, ok bool) {
    94  	return w.ctx.Deadline()
    95  }
    96  
    97  func (w *WaitGroup) Value(key any) any {
    98  	return w.ctx.Value(key)
    99  }
   100  
   101  func (w *WaitGroup) Go(fn func()) {
   102  	done := w.Add()
   103  	go func() {
   104  		defer done()
   105  		fn()
   106  	}()
   107  }