github.com/grailbio/base@v0.0.11/sync/workerpool/workerpool.go (about)

     1  package workerpool
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/grailbio/base/sync/multierror"
     8  	"v.io/x/lib/vlog"
     9  )
    10  
// Task is a unit of work executed by a WorkerPool worker.
//
// A worker runs a Task by calling Do with the TaskGroup the Task was enqueued
// on, which lets the Task enqueue follow-up Tasks on that same group. The
// error returned by Do is handed to the group's ErrHandler by the worker.
type Task interface {
	Do(grp *TaskGroup) error
}
    16  
// WorkerPool provides a mechanism for executing Tasks with a specific
// concurrency. A Task is an interface containing a single function Do.
// A TaskGroup allows Tasks to be grouped together so the parent process can
// wait for all Tasks in a TaskGroup to complete. Tasks can create new Tasks
// and add them to the TaskGroup, or create new TaskGroups and add them to
// the WorkerPool. A simple example looks like this (errs is a
// *multierror.Builder supplied by the caller):
//
//	wp := workerpool.New(context.Background(), 3)
//	tg1 := wp.NewTaskGroup("context1", errs)
//	tg1.Enqueue(MyFirstTask, true)
//	tg2 := wp.NewTaskGroup("context2", errs)
//	tg2.Enqueue(MyFourthTask, true)
//	tg1.Enqueue(MySecondTask, true)
//	tg2.Enqueue(MyFifthTask, true)
//	tg1.Enqueue(MyThirdTask, true)
//	tg1.Wait()
//	tg2.Enqueue(MySixthTask, true)
//	tg2.Wait()
//	wp.Wait()
//
// TaskGroups can come and go until wp.Wait() has been called. Tasks can come
// and go in a TaskGroup until tg.Wait() has been called. All the Tasks
// in this example are executed by 3 goroutines.
//
// Note: New starts the worker goroutines, which all consume from one shared
// queue. Outstanding TaskGroups and outstanding Tasks are tracked with
// sync.WaitGroups.
type WorkerPool struct {
	// Ctx is the context observed by the workers; once it is cancelled they
	// start discarding queued Tasks instead of executing them.
	Ctx context.Context
	// Concurrency is the number of worker goroutines started by New.
	Concurrency int
	queue       chan deliverable // Contains Tasks waiting to be executed.
	// ctxCounter counts TaskGroups created with NewTaskGroup whose Wait has
	// not yet returned.
	ctxCounter sync.WaitGroup
}
    51  
    52  // New creates a WorkerPool with the given concurrency.
    53  //
    54  // TODO(pknudsgaard): Should return a closure calling Wait.
    55  func New(ctx context.Context, concurrency int) *WorkerPool {
    56  	result := WorkerPool{
    57  		Ctx:         ctx,
    58  		Concurrency: concurrency,
    59  		queue:       make(chan deliverable, 10*concurrency),
    60  	}
    61  
    62  	for i := 0; i < concurrency; i++ {
    63  		go result.worker(result.queue)
    64  	}
    65  
    66  	return &result
    67  }
    68  
// TaskGroup is used to group Tasks together so the consumer can wait for a
// specific subgroup of Tasks to complete.
type TaskGroup struct {
	// Name identifies the group; it is used when logging the group's creation.
	Name string
	// ErrHandler receives the result of every Task's Do call in this group.
	ErrHandler *multierror.Builder
	// Wp is the WorkerPool whose queue and workers execute the group's Tasks.
	Wp       *WorkerPool
	activity sync.WaitGroup // Count active tasks
}
    77  
    78  // NewTaskGroup creates a TaskGroup for Tasks to be executed in.
    79  //
    80  // TODO(pknudsgaard): TaskGroup should have a context.Context which is
    81  // separate from the WorkerPool context.Context.
    82  //
    83  // TODO(pknudsgaard): Should return a closure calling Wait.
    84  func (wp *WorkerPool) NewTaskGroup(name string, errHandler *multierror.Builder) *TaskGroup {
    85  	vlog.VI(2).Infof("Creating TaskGroup: %s", name)
    86  
    87  	grp := &TaskGroup{
    88  		Name:       name,
    89  		ErrHandler: errHandler,
    90  		Wp:         wp,
    91  	}
    92  
    93  	wp.ctxCounter.Add(1)
    94  	return grp
    95  }
    96  
    97  // Enqueue puts a Task in the queue. If block is true and the channel is full,
    98  // then the function blocks. If block is false and the channel is full, then
    99  // the function returns false.
   100  func (grp *TaskGroup) Enqueue(t Task, block bool) bool {
   101  	var success bool
   102  
   103  	grp.activity.Add(1)
   104  	d := deliverable{grp: grp, t: t}
   105  	if block {
   106  		grp.Wp.queue <- d
   107  		success = true
   108  	} else {
   109  		select {
   110  		case grp.Wp.queue <- d:
   111  			success = true
   112  		default:
   113  			success = false
   114  		}
   115  	}
   116  
   117  	if !success {
   118  		grp.activity.Done()
   119  	}
   120  
   121  	return success
   122  }
   123  
   124  // Wait blocks until all Tasks in this TaskGroup have completed.
   125  func (grp *TaskGroup) Wait() {
   126  	// Trigger the director in case we were already at 0.
   127  	grp.activity.Wait()
   128  	grp.Wp.ctxCounter.Done()
   129  }
   130  
// deliverable is the unit carried on the WorkerPool queue: a Task paired with
// the TaskGroup it was enqueued on.
type deliverable struct {
	grp *TaskGroup // Group whose activity counter and ErrHandler apply to t.
	t   Task       // Task to execute.
}
   135  
   136  // worker is the goroutine for a worker. It will continue to consume and
   137  // execute tasks from the queue until the channel is closed or the TaskGroup is
   138  // Done.
   139  func (wp *WorkerPool) worker(dlv chan deliverable) {
   140  	vlog.VI(2).Infof("Starting worker")
   141  	defer vlog.VI(2).Infof("Ending worker")
   142  
   143  	for {
   144  		select {
   145  		case <-wp.Ctx.Done():
   146  			for d := range dlv {
   147  				d.grp.activity.Done()
   148  			}
   149  			return
   150  		case d, ok := <-dlv:
   151  			if !ok {
   152  				// Channel is closed, quit worker.
   153  				return
   154  			}
   155  			d.grp.ErrHandler.Add(d.t.Do(d.grp))
   156  			d.grp.activity.Done()
   157  		}
   158  	}
   159  }
   160  
   161  // Wait blocks until all TaskGroups in the WorkerPool have Waitd.
   162  func (wp *WorkerPool) Wait() {
   163  	// Trigger the director in case we were already at 0:
   164  	wp.ctxCounter.Wait()
   165  	close(wp.queue)
   166  }
   167  
   168  // Err returns the context.Context error to determine if WorkerPool Waitd
   169  // due to the context.
   170  func (wp *WorkerPool) Err() error {
   171  	return wp.Ctx.Err()
   172  }