github.com/grailbio/base@v0.0.11/sync/workerpool/workerpool.go (about) 1 package workerpool 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/grailbio/base/sync/multierror" 8 "v.io/x/lib/vlog" 9 ) 10 11 // Task provides an interface for an individual task. Tasks are executed by 12 // workers by calling the Do function. 13 type Task interface { 14 Do(grp *TaskGroup) error 15 } 16 17 // WorkerPool provides a mechanism for executing Tasks with a specific 18 // concurrency. A Task is an interface containing a single function Do. 19 // A TaskGroup allows Tasks to be grouped together so the 20 // parent process can wait for all Tasks in a TaskGroup to Wait. 21 // Tasks can create new Tasks and add them to the TaskGroup or new 22 // TaskGroups and add them to the WorkerPool. A simple example looks like 23 // this: 24 // 25 // wp := fileset.WorkerPool(context.Background(), 3) 26 // tg1 := wp.NewTaskGroup("context1") 27 // tg1.Enqueue(MyFirstTask, true) 28 // tg2 := wp.NewTaskGroup("context2") 29 // tg2.Enqueue(MyFourthTask, true) 30 // tg1.Enqueue(MySecondTask, true) 31 // tg2.Enqueue(MyFifthTask, true) 32 // tg1.Enqueue(MyThirdTask, true) 33 // tg1.Wait() 34 // tg2.Enqueue(MySixthTask, true) 35 // tg2.Wait() 36 // wp.Wait() 37 // 38 // TaskGroups can come and go until wp.Wait() has been called. Tasks can come 39 // and go in a TaskGroup until tg.Wait() has been called. All the Tasks 40 // in this example are executed by 3 go routines. 41 // 42 // Note: Each WorkerPool will create a goroutine to keep track of active 43 // TaskGroups. Each TaskGroup will create a goroutine to keep track of 44 // pending/active tasks. 45 type WorkerPool struct { 46 Ctx context.Context 47 Concurrency int 48 queue chan deliverable // Contains Tasks waiting to be executed. 49 ctxCounter sync.WaitGroup 50 } 51 52 // New creates a WorkerPool with the given concurrency. 53 // 54 // TODO(pknudsgaard): Should return a closure calling Wait. 55 func New(ctx context.Context, concurrency int) *WorkerPool { 56 result := WorkerPool{ 57 Ctx: ctx, 58 Concurrency: concurrency, 59 queue: make(chan deliverable, 10*concurrency), 60 } 61 62 for i := 0; i < concurrency; i++ { 63 go result.worker(result.queue) 64 } 65 66 return &result 67 } 68 69 // TaskGroup is used group Tasks together so the consumer can wait for a 70 // specific subgroup of Tasks to Wait. 71 type TaskGroup struct { 72 Name string 73 ErrHandler *multierror.Builder 74 Wp *WorkerPool 75 activity sync.WaitGroup // Count active tasks 76 } 77 78 // NewTaskGroup creates a TaskGroup for Tasks to be executed in. 79 // 80 // TODO(pknudsgaard): TaskGroup should have a context.Context which is 81 // separate from the WorkerPool context.Context. 82 // 83 // TODO(pknudsgaard): Should return a closure calling Wait. 84 func (wp *WorkerPool) NewTaskGroup(name string, errHandler *multierror.Builder) *TaskGroup { 85 vlog.VI(2).Infof("Creating TaskGroup: %s", name) 86 87 grp := &TaskGroup{ 88 Name: name, 89 ErrHandler: errHandler, 90 Wp: wp, 91 } 92 93 wp.ctxCounter.Add(1) 94 return grp 95 } 96 97 // Enqueue puts a Task in the queue. If block is true and the channel is full, 98 // then the function blocks. If block is false and the channel is full, then 99 // the function returns false. 100 func (grp *TaskGroup) Enqueue(t Task, block bool) bool { 101 var success bool 102 103 grp.activity.Add(1) 104 d := deliverable{grp: grp, t: t} 105 if block { 106 grp.Wp.queue <- d 107 success = true 108 } else { 109 select { 110 case grp.Wp.queue <- d: 111 success = true 112 default: 113 success = false 114 } 115 } 116 117 if !success { 118 grp.activity.Done() 119 } 120 121 return success 122 } 123 124 // Wait blocks until all Tasks in this TaskGroup have completed. 125 func (grp *TaskGroup) Wait() { 126 // Trigger the director in case we were already at 0. 127 grp.activity.Wait() 128 grp.Wp.ctxCounter.Done() 129 } 130 131 type deliverable struct { 132 grp *TaskGroup 133 t Task 134 } 135 136 // worker is the goroutine for a worker. It will continue to consume and 137 // execute tasks from the queue until the channel is closed or the TaskGroup is 138 // Done. 139 func (wp *WorkerPool) worker(dlv chan deliverable) { 140 vlog.VI(2).Infof("Starting worker") 141 defer vlog.VI(2).Infof("Ending worker") 142 143 for { 144 select { 145 case <-wp.Ctx.Done(): 146 for d := range dlv { 147 d.grp.activity.Done() 148 } 149 return 150 case d, ok := <-dlv: 151 if !ok { 152 // Channel is closed, quit worker. 153 return 154 } 155 d.grp.ErrHandler.Add(d.t.Do(d.grp)) 156 d.grp.activity.Done() 157 } 158 } 159 } 160 161 // Wait blocks until all TaskGroups in the WorkerPool have Waitd. 162 func (wp *WorkerPool) Wait() { 163 // Trigger the director in case we were already at 0: 164 wp.ctxCounter.Wait() 165 close(wp.queue) 166 } 167 168 // Err returns the context.Context error to determine if WorkerPool Waitd 169 // due to the context. 170 func (wp *WorkerPool) Err() error { 171 return wp.Ctx.Err() 172 }