github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/parallel/walk.go (about)

     1  package parallel
     2  
     3  import (
     4  	"context"
     5  	"io/fs"
     6  
     7  	"go.uber.org/zap"
     8  	"golang.org/x/sync/errgroup"
     9  	"golang.org/x/xerrors"
    10  
    11  	dio "github.com/aquasecurity/go-dep-parser/pkg/io"
    12  	"github.com/devseccon/trivy/pkg/log"
    13  )
    14  
// defaultParallel is the number of worker goroutines WalkDir starts when it is
// called with parallel set to 0.
const defaultParallel = 5
    16  
// onFile is the per-file callback used by WalkDir. It receives the file's path
// within the walked fs.FS, the fs.FileInfo obtained from the opened file, and
// the opened file itself as a dio.ReadSeekerAt, and returns a result of type T.
// When it returns an error, walk logs the error at debug level and skips the
// file instead of failing the whole walk.
type onFile[T any] func(string, fs.FileInfo, dio.ReadSeekerAt) (T, error)

// onWalkResult is the callback WalkDir invokes for each result produced by an
// onFile call. Returning a non-nil error makes WalkDir return that error.
type onWalkResult[T any] func(T) error
    19  
    20  func WalkDir[T any](ctx context.Context, fsys fs.FS, root string, parallel int,
    21  	onFile onFile[T], onResult onWalkResult[T]) error {
    22  
    23  	g, ctx := errgroup.WithContext(ctx)
    24  	paths := make(chan string)
    25  
    26  	g.Go(func() error {
    27  		defer close(paths)
    28  		err := fs.WalkDir(fsys, root, func(path string, d fs.DirEntry, err error) error {
    29  			if err != nil {
    30  				return err
    31  			} else if !d.Type().IsRegular() {
    32  				return nil
    33  			}
    34  
    35  			// check if file is empty
    36  			info, err := d.Info()
    37  			if err != nil {
    38  				return err
    39  			} else if info.Size() == 0 {
    40  				log.Logger.Debugf("%s is empty, skip this file", path)
    41  				return nil
    42  			}
    43  
    44  			select {
    45  			case paths <- path:
    46  			case <-ctx.Done():
    47  				return ctx.Err()
    48  			}
    49  			return nil
    50  		})
    51  		if err != nil {
    52  			return xerrors.Errorf("walk error: %w", err)
    53  		}
    54  		return nil
    55  	})
    56  
    57  	// Start a fixed number of goroutines to read and digest files.
    58  	c := make(chan T)
    59  	if parallel == 0 {
    60  		parallel = defaultParallel
    61  	}
    62  	for i := 0; i < parallel; i++ {
    63  		g.Go(func() error {
    64  			for path := range paths {
    65  				if err := walk(ctx, fsys, path, c, onFile); err != nil {
    66  					return err
    67  				}
    68  			}
    69  			return nil
    70  		})
    71  	}
    72  	go func() {
    73  		_ = g.Wait()
    74  		close(c)
    75  	}()
    76  
    77  	for res := range c {
    78  		if err := onResult(res); err != nil {
    79  			return err
    80  		}
    81  	}
    82  	// Check whether any of the goroutines failed. Since g is accumulating the
    83  	// errors, we don't need to send them (or check for them) in the individual
    84  	// results sent on the channel.
    85  	if err := g.Wait(); err != nil {
    86  		return err
    87  	}
    88  	return nil
    89  }
    90  
    91  func walk[T any](ctx context.Context, fsys fs.FS, path string, c chan T, onFile onFile[T]) error {
    92  	f, err := fsys.Open(path)
    93  	if err != nil {
    94  		return xerrors.Errorf("file open error: %w", err)
    95  	}
    96  	defer f.Close()
    97  
    98  	info, err := f.Stat()
    99  	if err != nil {
   100  		return xerrors.Errorf("stat error: %w", err)
   101  	}
   102  
   103  	rsa, ok := f.(dio.ReadSeekerAt)
   104  	if !ok {
   105  		return xerrors.New("type assertion failed")
   106  	}
   107  	res, err := onFile(path, info, rsa)
   108  	if err != nil {
   109  		log.Logger.Debugw("Walk error", zap.String("file_path", path), zap.Error(err))
   110  		return nil
   111  	}
   112  
   113  	select {
   114  	case c <- res:
   115  	case <-ctx.Done():
   116  		return ctx.Err()
   117  	}
   118  	return nil
   119  }