bitbucket.org/ai69/amoy@v0.2.3/worker.go (about)

     1  package amoy
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"golang.org/x/sync/errgroup"
     8  )
     9  
    10  type (
    11  	// TaskRunFunc is a function that runs a task and returns the result with error.
    12  	TaskRunFunc func(ctx context.Context, taskID int) (interface{}, error)
    13  	// TaskResultFunc is a function that handles the result of a task.
    14  	TaskResultFunc func(ctx context.Context, taskID int, result interface{}, err error)
    15  )
    16  
    17  // ParallelTaskRun runs a given number of tasks and handles results in parallel.
    18  func ParallelTaskRun(ctx context.Context, workerNum, taskNum int, runFunc TaskRunFunc, doneFunc TaskResultFunc) error {
    19  	// precondition check
    20  	if workerNum < 1 {
    21  		return errors.New("invalid worker number")
    22  	}
    23  	if taskNum < 1 {
    24  		return errors.New("invalid task number")
    25  	}
    26  
    27  	// correct worker number
    28  	num := EnsureRange(workerNum, 1, taskNum)
    29  	g, ctx := errgroup.WithContext(ctx)
    30  
    31  	// send task
    32  	taskCh := make(chan int)
    33  	g.Go(func() error {
    34  		defer close(taskCh)
    35  		// Sending task IDs to channel
    36  		for i := 0; i < taskNum; i++ {
    37  			select {
    38  			case taskCh <- i:
    39  			case <-ctx.Done():
    40  				return ctx.Err()
    41  			}
    42  		}
    43  		return nil
    44  	})
    45  
    46  	// result channel
    47  	type result struct {
    48  		task int
    49  		res  interface{}
    50  		err  error
    51  	}
    52  	resultCh := make(chan result)
    53  
    54  	// creating a worker pool to run tasks
    55  	for i := 0; i < num; i++ {
    56  		g.Go(func() error {
    57  			for task := range taskCh {
    58  				res, err := runFunc(ctx, task)
    59  				select {
    60  				case <-ctx.Done():
    61  					return ctx.Err()
    62  				default:
    63  					resultCh <- result{task, res, err}
    64  				}
    65  			}
    66  			return nil
    67  		})
    68  	}
    69  
    70  	// receive results
    71  	go func() {
    72  		g.Wait()
    73  		close(resultCh)
    74  	}()
    75  	for res := range resultCh {
    76  		if doneFunc != nil {
    77  			doneFunc(ctx, res.task, res.res, res.err)
    78  		}
    79  	}
    80  
    81  	// waiting for all the goroutines to finish
    82  	return g.Wait()
    83  }