github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/wait_group.go (about) 1 package pr2 2 3 import ( 4 "context" 5 "sync" 6 "time" 7 8 "github.com/reusee/e5" 9 ) 10 11 type waitGroupKey struct{} 12 13 var WaitGroupKey = waitGroupKey{} 14 15 type WaitGroup struct { 16 ctx context.Context 17 wg *sync.WaitGroup 18 cancel context.CancelFunc 19 } 20 21 func NewWaitGroup(ctx context.Context) *WaitGroup { 22 23 if v := ctx.Value(WaitGroupKey); v != nil { 24 // if there is parent wait group, derive from it 25 parentWaitGroup := v.(*WaitGroup) 26 parentWaitGroup.wg.Add(1) 27 ctx, cancel := context.WithCancel(ctx) 28 wg := &WaitGroup{ 29 wg: new(sync.WaitGroup), 30 cancel: cancel, 31 } 32 ctx = context.WithValue(ctx, WaitGroupKey, wg) 33 wg.ctx = ctx 34 go func() { 35 <-ctx.Done() 36 wg.wg.Wait() 37 parentWaitGroup.wg.Done() 38 }() 39 return wg 40 } 41 42 // new root wait group 43 ctx, cancel := context.WithCancel(ctx) 44 wg := &WaitGroup{ 45 wg: new(sync.WaitGroup), 46 cancel: cancel, 47 } 48 ctx = context.WithValue(ctx, WaitGroupKey, wg) 49 wg.ctx = ctx 50 return wg 51 } 52 53 func GetWaitGroup(ctx context.Context) *WaitGroup { 54 if v := ctx.Value(WaitGroupKey); v != nil { 55 return v.(*WaitGroup) 56 } 57 return nil 58 } 59 60 func (w *WaitGroup) Cancel() { 61 w.cancel() 62 } 63 64 func (w *WaitGroup) Add() (done func()) { 65 select { 66 case <-w.ctx.Done(): 67 e5.Throw(context.Canceled) 68 default: 69 } 70 w.wg.Add(1) 71 var doneOnce sync.Once 72 return func() { 73 doneOnce.Do(func() { 74 w.wg.Done() 75 }) 76 } 77 } 78 79 func (w *WaitGroup) Wait() { 80 w.wg.Wait() 81 } 82 83 var _ context.Context = new(WaitGroup) 84 85 func (w *WaitGroup) Done() <-chan struct{} { 86 return w.ctx.Done() 87 } 88 89 func (w *WaitGroup) Err() error { 90 return w.ctx.Err() 91 } 92 93 func (w *WaitGroup) Deadline() (deadline time.Time, ok bool) { 94 return w.ctx.Deadline() 95 } 96 97 func (w *WaitGroup) Value(key any) any { 98 return w.ctx.Value(key) 99 } 100 101 func (w *WaitGroup) Go(fn func()) { 102 done := w.Add() 103 go func() { 104 defer done() 105 fn() 106 }() 107 }