go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/providers-sdk/v1/util/jobpool/jobpool.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package jobpool
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"sync"
    10  )
    11  
    12  type JobResult interface{}
    13  
    14  // Job encapsulates a work item that should go in a work pool.
    15  type Job struct {
    16  	Err    error
    17  	Result JobResult
    18  	f      func() (JobResult, error)
    19  }
    20  
    21  // NewJob initializes a new job based on given params.
    22  func NewJob(f func() (JobResult, error)) *Job {
    23  	return &Job{f: f}
    24  }
    25  
    26  // Run runs a job and does appropriate accounting via a given sync.WorkGroup
    27  func (t *Job) Run(wg *sync.WaitGroup) {
    28  	if t.f == nil {
    29  		t.Err = fmt.Errorf("no function to run in jobpool: %s", t.Err)
    30  	} else {
    31  		t.Result, t.Err = t.f()
    32  	}
    33  	wg.Done()
    34  }
    35  
    36  // Pool is a worker group that runs a number of jobs
    37  type Pool struct {
    38  	Jobs []*Job
    39  
    40  	concurrency int // the amount of jobs to run concurrently
    41  	jobsChan    chan *Job
    42  	wg          sync.WaitGroup
    43  }
    44  
    45  // CreatePool takes a slice of jobs and a concurrency int, creating a channel to handle the jobs
    46  func CreatePool(jobs []*Job, concurrency int) *Pool {
    47  	return &Pool{
    48  		Jobs:        jobs,
    49  		concurrency: concurrency,
    50  		jobsChan:    make(chan *Job),
    51  	}
    52  }
    53  
    54  // HasErrors returns a bool base on the existence of errors in the job.
    55  func (p *Pool) HasErrors() bool {
    56  	for _, job := range p.Jobs {
    57  		if job.Err != nil {
    58  			return true
    59  		}
    60  	}
    61  	return false
    62  }
    63  
    64  // GetErrors returns all errors from jobs run.
    65  func (p *Pool) GetErrors() error {
    66  	var err error
    67  	for _, job := range p.Jobs {
    68  		if job.Err != nil {
    69  			if err == nil {
    70  				err = job.Err
    71  			} else {
    72  				err = errors.Join(err, job.Err)
    73  			}
    74  		}
    75  	}
    76  	return err
    77  }
    78  
    79  // Run runs all work within the pool and blocks until it's finished.
    80  func (p *Pool) Run() {
    81  	for i := 0; i < p.concurrency; i++ {
    82  		go p.work()
    83  	}
    84  
    85  	p.wg.Add(len(p.Jobs))
    86  	for i := range p.Jobs {
    87  		p.jobsChan <- p.Jobs[i]
    88  	}
    89  
    90  	// all workers return
    91  	close(p.jobsChan)
    92  
    93  	p.wg.Wait()
    94  }
    95  
    96  // The work loop for any single goroutine.
    97  func (p *Pool) work() {
    98  	for job := range p.jobsChan {
    99  		job.Run(&p.wg)
   100  	}
   101  }