github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/parallel/pipeline.go (about)

     1  package parallel
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/cheggaaa/pb/v3"
     7  	"golang.org/x/sync/errgroup"
     8  )
     9  
// defaultWorkers is the number of worker goroutines NewPipeline falls back to
// when it is called with numWorkers set to 0.
const defaultWorkers = 5
    11  
// Pipeline represents a structure for performing parallel processing.
// T represents the input element type and U represents the output element type.
//
// Use NewPipeline to build a value and Do to run it.
type Pipeline[T, U any] struct {
	// numWorkers is the number of worker goroutines started by Do.
	numWorkers int
	// items holds the input elements that Do feeds to the workers.
	items      []T
	// onItem is called by the workers for each input element.
	onItem     onItem[T, U]
	// onResult is called by Do for each output element produced by onItem.
	onResult   onResult[U]
	// progress enables the progress bar shown while Do is running.
	progress   bool
}
    21  
// onItem represents a function type that takes an input element and returns an output element.
// Do calls it from its worker goroutines, so it may run concurrently with itself.
// Returning a non-nil error stops the calling worker and makes Do report an error.
type onItem[T, U any] func(context.Context, T) (U, error)
    24  
// onResult represents a function type that takes an output element.
// Do calls it from the goroutine that invoked Do, one result at a time.
// Returning a non-nil error makes Do return that error.
type onResult[U any] func(U) error
    27  
    28  func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
    29  	fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
    30  	if fn2 == nil {
    31  		// In case where there is no need to process the return values
    32  		fn2 = func(_ U) error { return nil }
    33  	}
    34  	if numWorkers == 0 {
    35  		numWorkers = defaultWorkers
    36  	}
    37  	return Pipeline[T, U]{
    38  		numWorkers: numWorkers,
    39  		progress:   progress,
    40  		items:      items,
    41  		onItem:     fn1,
    42  		onResult:   fn2,
    43  	}
    44  }
    45  
    46  // Do executes pipeline processing.
    47  // It exits when any error occurs.
    48  func (p *Pipeline[T, U]) Do(ctx context.Context) error {
    49  	// progress bar
    50  	var bar *pb.ProgressBar
    51  	if p.progress {
    52  		bar = pb.StartNew(len(p.items))
    53  		defer bar.Finish()
    54  	}
    55  
    56  	g, ctx := errgroup.WithContext(ctx)
    57  	itemCh := make(chan T)
    58  
    59  	// Start a goroutine to send input data
    60  	g.Go(func() error {
    61  		defer close(itemCh)
    62  		for _, item := range p.items {
    63  			if p.progress {
    64  				bar.Increment()
    65  			}
    66  			select {
    67  			case itemCh <- item:
    68  			case <-ctx.Done():
    69  				return ctx.Err()
    70  			}
    71  		}
    72  		return nil
    73  	})
    74  
    75  	// Generate a channel for sending output data
    76  	results := make(chan U)
    77  
    78  	// Start a fixed number of goroutines to process items.
    79  	for i := 0; i < p.numWorkers; i++ {
    80  		g.Go(func() error {
    81  			for item := range itemCh {
    82  				res, err := p.onItem(ctx, item)
    83  				if err != nil {
    84  					return err
    85  				}
    86  				select {
    87  				case results <- res:
    88  				case <-ctx.Done():
    89  					return ctx.Err()
    90  				}
    91  			}
    92  			return nil
    93  		})
    94  	}
    95  
    96  	go func() {
    97  		_ = g.Wait()
    98  		close(results)
    99  	}()
   100  
   101  	// Process output data received from the channel
   102  	for res := range results {
   103  		if err := p.onResult(res); err != nil {
   104  			return err
   105  		}
   106  	}
   107  
   108  	// Check whether any of the goroutines failed. Since g is accumulating the
   109  	// errors, we don't need to send them (or check for them) in the individual
   110  	// results sent on the channel.
   111  	if err := g.Wait(); err != nil {
   112  		return err
   113  	}
   114  	return nil
   115  }