github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/parallel/pipeline.go (about) 1 package parallel 2 3 import ( 4 "context" 5 6 "github.com/cheggaaa/pb/v3" 7 "golang.org/x/sync/errgroup" 8 ) 9 10 const defaultWorkers = 5 11 12 // Pipeline represents a structure for performing parallel processing. 13 // T represents the input element type and U represents the output element type. 14 type Pipeline[T, U any] struct { 15 numWorkers int 16 items []T 17 onItem onItem[T, U] 18 onResult onResult[U] 19 progress bool 20 } 21 22 // onItem represents a function type that takes an input element and returns an output element. 23 type onItem[T, U any] func(context.Context, T) (U, error) 24 25 // onResult represents a function type that takes an output element. 26 type onResult[U any] func(U) error 27 28 func NewPipeline[T, U any](numWorkers int, progress bool, items []T, 29 fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] { 30 if fn2 == nil { 31 // In case where there is no need to process the return values 32 fn2 = func(_ U) error { return nil } 33 } 34 if numWorkers == 0 { 35 numWorkers = defaultWorkers 36 } 37 return Pipeline[T, U]{ 38 numWorkers: numWorkers, 39 progress: progress, 40 items: items, 41 onItem: fn1, 42 onResult: fn2, 43 } 44 } 45 46 // Do executes pipeline processing. 47 // It exits when any error occurs. 48 func (p *Pipeline[T, U]) Do(ctx context.Context) error { 49 // progress bar 50 var bar *pb.ProgressBar 51 if p.progress { 52 bar = pb.StartNew(len(p.items)) 53 defer bar.Finish() 54 } 55 56 g, ctx := errgroup.WithContext(ctx) 57 itemCh := make(chan T) 58 59 // Start a goroutine to send input data 60 g.Go(func() error { 61 defer close(itemCh) 62 for _, item := range p.items { 63 if p.progress { 64 bar.Increment() 65 } 66 select { 67 case itemCh <- item: 68 case <-ctx.Done(): 69 return ctx.Err() 70 } 71 } 72 return nil 73 }) 74 75 // Generate a channel for sending output data 76 results := make(chan U) 77 78 // Start a fixed number of goroutines to process items. 79 for i := 0; i < p.numWorkers; i++ { 80 g.Go(func() error { 81 for item := range itemCh { 82 res, err := p.onItem(ctx, item) 83 if err != nil { 84 return err 85 } 86 select { 87 case results <- res: 88 case <-ctx.Done(): 89 return ctx.Err() 90 } 91 } 92 return nil 93 }) 94 } 95 96 go func() { 97 _ = g.Wait() 98 close(results) 99 }() 100 101 // Process output data received from the channel 102 for res := range results { 103 if err := p.onResult(res); err != nil { 104 return err 105 } 106 } 107 108 // Check whether any of the goroutines failed. Since g is accumulating the 109 // errors, we don't need to send them (or check for them) in the individual 110 // results sent on the channel. 111 if err := g.Wait(); err != nil { 112 return err 113 } 114 return nil 115 }