bitbucket.org/ai69/amoy@v0.2.3/worker.go (about) 1 package amoy 2 3 import ( 4 "context" 5 "errors" 6 7 "golang.org/x/sync/errgroup" 8 ) 9 10 type ( 11 // TaskRunFunc is a function that runs a task and returns the result with error. 12 TaskRunFunc func(ctx context.Context, taskID int) (interface{}, error) 13 // TaskResultFunc is a function that handles the result of a task. 14 TaskResultFunc func(ctx context.Context, taskID int, result interface{}, err error) 15 ) 16 17 // ParallelTaskRun runs a given number of tasks and handles results in parallel. 18 func ParallelTaskRun(ctx context.Context, workerNum, taskNum int, runFunc TaskRunFunc, doneFunc TaskResultFunc) error { 19 // precondition check 20 if workerNum < 1 { 21 return errors.New("invalid worker number") 22 } 23 if taskNum < 1 { 24 return errors.New("invalid task number") 25 } 26 27 // correct worker number 28 num := EnsureRange(workerNum, 1, taskNum) 29 g, ctx := errgroup.WithContext(ctx) 30 31 // send task 32 taskCh := make(chan int) 33 g.Go(func() error { 34 defer close(taskCh) 35 // Sending task IDs to channel 36 for i := 0; i < taskNum; i++ { 37 select { 38 case taskCh <- i: 39 case <-ctx.Done(): 40 return ctx.Err() 41 } 42 } 43 return nil 44 }) 45 46 // result channel 47 type result struct { 48 task int 49 res interface{} 50 err error 51 } 52 resultCh := make(chan result) 53 54 // creating a worker pool to run tasks 55 for i := 0; i < num; i++ { 56 g.Go(func() error { 57 for task := range taskCh { 58 res, err := runFunc(ctx, task) 59 select { 60 case <-ctx.Done(): 61 return ctx.Err() 62 default: 63 resultCh <- result{task, res, err} 64 } 65 } 66 return nil 67 }) 68 } 69 70 // receive results 71 go func() { 72 g.Wait() 73 close(resultCh) 74 }() 75 for res := range resultCh { 76 if doneFunc != nil { 77 doneFunc(ctx, res.task, res.res, res.err) 78 } 79 } 80 81 // waiting for all the goroutines to finish 82 return g.Wait() 83 }