github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/cloud/aws/commands/run.go (about)

     1  package commands
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strings"
     7  
     8  	"github.com/aws/aws-sdk-go-v2/service/sts"
     9  	"golang.org/x/exp/slices"
    10  	"golang.org/x/xerrors"
    11  
    12  	"github.com/aquasecurity/trivy-aws/pkg/errs"
    13  	awsScanner "github.com/aquasecurity/trivy-aws/pkg/scanner"
    14  	"github.com/devseccon/trivy/pkg/cloud"
    15  	"github.com/devseccon/trivy/pkg/cloud/aws/config"
    16  	"github.com/devseccon/trivy/pkg/cloud/aws/scanner"
    17  	"github.com/devseccon/trivy/pkg/cloud/report"
    18  	"github.com/devseccon/trivy/pkg/commands/operation"
    19  	"github.com/devseccon/trivy/pkg/flag"
    20  	"github.com/devseccon/trivy/pkg/log"
    21  )
    22  
    23  var allSupportedServicesFunc = awsScanner.AllSupportedServices
    24  
    25  func getAccountIDAndRegion(ctx context.Context, region, endpoint string) (string, string, error) {
    26  	log.Logger.Debug("Looking for AWS credentials provider...")
    27  
    28  	cfg, err := config.LoadDefaultAWSConfig(ctx, region, endpoint)
    29  	if err != nil {
    30  		return "", "", err
    31  	}
    32  
    33  	svc := sts.NewFromConfig(cfg)
    34  
    35  	log.Logger.Debug("Looking up AWS caller identity...")
    36  	result, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
    37  	if err != nil {
    38  		return "", "", xerrors.Errorf("failed to discover AWS caller identity: %w", err)
    39  	}
    40  	if result.Account == nil {
    41  		return "", "", xerrors.Errorf("missing account id for aws account")
    42  	}
    43  	log.Logger.Debugf("Verified AWS credentials for account %s!", *result.Account)
    44  	return *result.Account, cfg.Region, nil
    45  }
    46  
    47  func validateServicesInput(services, skipServices []string) error {
    48  	for _, s := range services {
    49  		for _, ss := range skipServices {
    50  			if s == ss {
    51  				return xerrors.Errorf("service: %s specified to both skip and include", s)
    52  			}
    53  		}
    54  	}
    55  	return nil
    56  }
    57  
    58  func processOptions(ctx context.Context, opt *flag.Options) error {
    59  	if err := validateServicesInput(opt.Services, opt.SkipServices); err != nil {
    60  		return err
    61  	}
    62  
    63  	// support comma separated services too
    64  	var splitServices []string
    65  	for _, service := range opt.Services {
    66  		splitServices = append(splitServices, strings.Split(service, ",")...)
    67  	}
    68  	opt.Services = splitServices
    69  
    70  	var splitSkipServices []string
    71  	for _, skipService := range opt.SkipServices {
    72  		splitSkipServices = append(splitSkipServices, strings.Split(skipService, ",")...)
    73  	}
    74  	opt.SkipServices = splitSkipServices
    75  
    76  	if len(opt.Services) != 1 && opt.ARN != "" {
    77  		return xerrors.Errorf("you must specify the single --service which the --arn relates to")
    78  	}
    79  
    80  	if opt.Account == "" || opt.Region == "" {
    81  		var err error
    82  		opt.Account, opt.Region, err = getAccountIDAndRegion(ctx, opt.Region, opt.Endpoint)
    83  		if err != nil {
    84  			return err
    85  		}
    86  	}
    87  
    88  	err := filterServices(opt)
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	log.Logger.Debug("scanning services: ", opt.Services)
    94  	return nil
    95  }
    96  
    97  func filterServices(opt *flag.Options) error {
    98  	switch {
    99  	case len(opt.Services) == 0 && len(opt.SkipServices) == 0:
   100  		log.Logger.Debug("No service(s) specified, scanning all services...")
   101  		opt.Services = allSupportedServicesFunc()
   102  	case len(opt.SkipServices) > 0:
   103  		log.Logger.Debug("excluding services: ", opt.SkipServices)
   104  		for _, s := range allSupportedServicesFunc() {
   105  			if slices.Contains(opt.SkipServices, s) {
   106  				continue
   107  			}
   108  			if !slices.Contains(opt.Services, s) {
   109  				opt.Services = append(opt.Services, s)
   110  			}
   111  		}
   112  	case len(opt.Services) > 0:
   113  		log.Logger.Debugf("Specific services were requested: [%s]...", strings.Join(opt.Services, ", "))
   114  		for _, service := range opt.Services {
   115  			var found bool
   116  			supported := allSupportedServicesFunc()
   117  			for _, allowed := range supported {
   118  				if allowed == service {
   119  					found = true
   120  					break
   121  				}
   122  			}
   123  			if !found {
   124  				return xerrors.Errorf("service '%s' is not currently supported - supported services are: %s", service, strings.Join(supported, ", "))
   125  			}
   126  		}
   127  	}
   128  	return nil
   129  }
   130  
   131  func Run(ctx context.Context, opt flag.Options) error {
   132  
   133  	ctx, cancel := context.WithTimeout(ctx, opt.GlobalOptions.Timeout)
   134  	defer cancel()
   135  
   136  	if err := log.InitLogger(opt.Debug, false); err != nil {
   137  		return xerrors.Errorf("logger error: %w", err)
   138  	}
   139  
   140  	var err error
   141  	defer func() {
   142  		if errors.Is(err, context.DeadlineExceeded) {
   143  			log.Logger.Warn("Increase --timeout value")
   144  		}
   145  	}()
   146  
   147  	if err := processOptions(ctx, &opt); err != nil {
   148  		return err
   149  	}
   150  
   151  	results, cached, err := scanner.NewScanner().Scan(ctx, opt)
   152  	if err != nil {
   153  		var aerr errs.AdapterError
   154  		if errors.As(err, &aerr) {
   155  			for _, e := range aerr.Errors() {
   156  				log.Logger.Warnf("Adapter error: %s", e)
   157  			}
   158  		} else {
   159  			return xerrors.Errorf("aws scan error: %w", err)
   160  		}
   161  	}
   162  
   163  	log.Logger.Debug("Writing report to output...")
   164  
   165  	res := results.GetFailed()
   166  	if opt.MisconfOptions.IncludeNonFailures {
   167  		res = results
   168  	}
   169  
   170  	r := report.New(cloud.ProviderAWS, opt.Account, opt.Region, res, opt.Services)
   171  	if err := report.Write(r, opt, cached); err != nil {
   172  		return xerrors.Errorf("unable to write results: %w", err)
   173  	}
   174  
   175  	operation.Exit(opt, r.Failed())
   176  	return nil
   177  }