github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/cloud/aws/commands/run.go (about) 1 package commands 2 3 import ( 4 "context" 5 "errors" 6 "strings" 7 8 "github.com/aws/aws-sdk-go-v2/service/sts" 9 "golang.org/x/exp/slices" 10 "golang.org/x/xerrors" 11 12 "github.com/aquasecurity/trivy-aws/pkg/errs" 13 awsScanner "github.com/aquasecurity/trivy-aws/pkg/scanner" 14 "github.com/devseccon/trivy/pkg/cloud" 15 "github.com/devseccon/trivy/pkg/cloud/aws/config" 16 "github.com/devseccon/trivy/pkg/cloud/aws/scanner" 17 "github.com/devseccon/trivy/pkg/cloud/report" 18 "github.com/devseccon/trivy/pkg/commands/operation" 19 "github.com/devseccon/trivy/pkg/flag" 20 "github.com/devseccon/trivy/pkg/log" 21 ) 22 23 var allSupportedServicesFunc = awsScanner.AllSupportedServices 24 25 func getAccountIDAndRegion(ctx context.Context, region, endpoint string) (string, string, error) { 26 log.Logger.Debug("Looking for AWS credentials provider...") 27 28 cfg, err := config.LoadDefaultAWSConfig(ctx, region, endpoint) 29 if err != nil { 30 return "", "", err 31 } 32 33 svc := sts.NewFromConfig(cfg) 34 35 log.Logger.Debug("Looking up AWS caller identity...") 36 result, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) 37 if err != nil { 38 return "", "", xerrors.Errorf("failed to discover AWS caller identity: %w", err) 39 } 40 if result.Account == nil { 41 return "", "", xerrors.Errorf("missing account id for aws account") 42 } 43 log.Logger.Debugf("Verified AWS credentials for account %s!", *result.Account) 44 return *result.Account, cfg.Region, nil 45 } 46 47 func validateServicesInput(services, skipServices []string) error { 48 for _, s := range services { 49 for _, ss := range skipServices { 50 if s == ss { 51 return xerrors.Errorf("service: %s specified to both skip and include", s) 52 } 53 } 54 } 55 return nil 56 } 57 58 func processOptions(ctx context.Context, opt *flag.Options) error { 59 if err := validateServicesInput(opt.Services, opt.SkipServices); err != nil { 60 return err 61 } 62 63 // support comma separated services too 64 var splitServices []string 65 for _, service := range opt.Services { 66 splitServices = append(splitServices, strings.Split(service, ",")...) 67 } 68 opt.Services = splitServices 69 70 var splitSkipServices []string 71 for _, skipService := range opt.SkipServices { 72 splitSkipServices = append(splitSkipServices, strings.Split(skipService, ",")...) 73 } 74 opt.SkipServices = splitSkipServices 75 76 if len(opt.Services) != 1 && opt.ARN != "" { 77 return xerrors.Errorf("you must specify the single --service which the --arn relates to") 78 } 79 80 if opt.Account == "" || opt.Region == "" { 81 var err error 82 opt.Account, opt.Region, err = getAccountIDAndRegion(ctx, opt.Region, opt.Endpoint) 83 if err != nil { 84 return err 85 } 86 } 87 88 err := filterServices(opt) 89 if err != nil { 90 return err 91 } 92 93 log.Logger.Debug("scanning services: ", opt.Services) 94 return nil 95 } 96 97 func filterServices(opt *flag.Options) error { 98 switch { 99 case len(opt.Services) == 0 && len(opt.SkipServices) == 0: 100 log.Logger.Debug("No service(s) specified, scanning all services...") 101 opt.Services = allSupportedServicesFunc() 102 case len(opt.SkipServices) > 0: 103 log.Logger.Debug("excluding services: ", opt.SkipServices) 104 for _, s := range allSupportedServicesFunc() { 105 if slices.Contains(opt.SkipServices, s) { 106 continue 107 } 108 if !slices.Contains(opt.Services, s) { 109 opt.Services = append(opt.Services, s) 110 } 111 } 112 case len(opt.Services) > 0: 113 log.Logger.Debugf("Specific services were requested: [%s]...", strings.Join(opt.Services, ", ")) 114 for _, service := range opt.Services { 115 var found bool 116 supported := allSupportedServicesFunc() 117 for _, allowed := range supported { 118 if allowed == service { 119 found = true 120 break 121 } 122 } 123 if !found { 124 return xerrors.Errorf("service '%s' is not currently supported - supported services are: %s", service, strings.Join(supported, ", ")) 125 } 126 } 127 } 128 return nil 129 } 130 131 func Run(ctx context.Context, opt flag.Options) error { 132 133 ctx, cancel := context.WithTimeout(ctx, opt.GlobalOptions.Timeout) 134 defer cancel() 135 136 if err := log.InitLogger(opt.Debug, false); err != nil { 137 return xerrors.Errorf("logger error: %w", err) 138 } 139 140 var err error 141 defer func() { 142 if errors.Is(err, context.DeadlineExceeded) { 143 log.Logger.Warn("Increase --timeout value") 144 } 145 }() 146 147 if err := processOptions(ctx, &opt); err != nil { 148 return err 149 } 150 151 results, cached, err := scanner.NewScanner().Scan(ctx, opt) 152 if err != nil { 153 var aerr errs.AdapterError 154 if errors.As(err, &aerr) { 155 for _, e := range aerr.Errors() { 156 log.Logger.Warnf("Adapter error: %s", e) 157 } 158 } else { 159 return xerrors.Errorf("aws scan error: %w", err) 160 } 161 } 162 163 log.Logger.Debug("Writing report to output...") 164 165 res := results.GetFailed() 166 if opt.MisconfOptions.IncludeNonFailures { 167 res = results 168 } 169 170 r := report.New(cloud.ProviderAWS, opt.Account, opt.Region, res, opt.Services) 171 if err := report.Write(r, opt, cached); err != nil { 172 return xerrors.Errorf("unable to write results: %w", err) 173 } 174 175 operation.Exit(opt, r.Failed()) 176 return nil 177 }