github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/adapt.go (about) 1 package aws 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/khulnasoft-lab/defsec/pkg/concurrency" 8 "github.com/khulnasoft-lab/defsec/pkg/errs" 9 "github.com/khulnasoft-lab/defsec/pkg/types" 10 11 "github.com/khulnasoft-lab/defsec/pkg/debug" 12 13 "github.com/aws/aws-sdk-go-v2/service/sts" 14 15 "github.com/aws/aws-sdk-go-v2/aws" 16 "github.com/aws/aws-sdk-go-v2/aws/arn" 17 "github.com/aws/aws-sdk-go-v2/config" 18 "github.com/khulnasoft-lab/defsec/internal/adapters/cloud/options" 19 "github.com/khulnasoft-lab/defsec/pkg/progress" 20 "github.com/khulnasoft-lab/defsec/pkg/state" 21 ) 22 23 var registeredAdapters []ServiceAdapter 24 25 func RegisterServiceAdapter(adapter ServiceAdapter) { 26 for _, existing := range registeredAdapters { 27 if existing.Name() == adapter.Name() { 28 panic(fmt.Sprintf("duplicate service adapter: %s", adapter.Name())) 29 } 30 } 31 registeredAdapters = append(registeredAdapters, adapter) 32 } 33 34 type ServiceAdapter interface { 35 Name() string 36 Provider() string 37 Adapt(root *RootAdapter, state *state.State) error 38 } 39 40 type RootAdapter struct { 41 ctx context.Context 42 sessionCfg aws.Config 43 tracker progress.ServiceTracker 44 accountID string 45 currentService string 46 region string 47 debugWriter debug.Logger 48 concurrencyStrategy concurrency.Strategy 49 } 50 51 func NewRootAdapter(ctx context.Context, cfg aws.Config, tracker progress.ServiceTracker) *RootAdapter { 52 return &RootAdapter{ 53 ctx: ctx, 54 tracker: tracker, 55 sessionCfg: cfg, 56 region: cfg.Region, 57 } 58 } 59 60 func (a *RootAdapter) Region() string { 61 return a.region 62 } 63 64 func (a *RootAdapter) Debug(format string, args ...interface{}) { 65 a.debugWriter.Log(format, args...) 66 } 67 68 func (a *RootAdapter) ConcurrencyStrategy() concurrency.Strategy { 69 return a.concurrencyStrategy 70 } 71 72 func (a *RootAdapter) SessionConfig() aws.Config { 73 return a.sessionCfg 74 } 75 76 func (a *RootAdapter) Context() context.Context { 77 return a.ctx 78 } 79 80 func (a *RootAdapter) Tracker() progress.ServiceTracker { 81 return a.tracker 82 } 83 84 func (a *RootAdapter) CreateMetadata(resource string) types.Metadata { 85 86 // some services don't require region/account id in the ARN 87 // see https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces 88 namespace := a.accountID 89 region := a.region 90 switch a.currentService { 91 case "s3": 92 namespace = "" 93 region = "" 94 } 95 96 return a.CreateMetadataFromARN((arn.ARN{ 97 Partition: "aws", 98 Service: a.currentService, 99 Region: region, 100 AccountID: namespace, 101 Resource: resource, 102 }).String()) 103 } 104 105 func (a *RootAdapter) CreateMetadataFromARN(arn string) types.Metadata { 106 return types.NewRemoteMetadata(arn) 107 } 108 109 type resolver struct { 110 endpoint string 111 } 112 113 func (r *resolver) ResolveEndpoint(_, _ string, _ ...interface{}) (aws.Endpoint, error) { 114 return aws.Endpoint{ 115 URL: r.endpoint, 116 SigningRegion: "custom-signing-region", 117 Source: aws.EndpointSourceCustom, 118 }, nil 119 } 120 121 func createResolver(endpoint string) aws.EndpointResolverWithOptions { 122 return &resolver{ 123 endpoint: endpoint, 124 } 125 } 126 127 func AllServices() []string { 128 var services []string 129 for _, reg := range registeredAdapters { 130 services = append(services, reg.Name()) 131 } 132 return services 133 } 134 135 func Adapt(ctx context.Context, state *state.State, opt options.Options) error { 136 c := &RootAdapter{ 137 ctx: ctx, 138 tracker: opt.ProgressTracker, 139 debugWriter: opt.DebugWriter.Extend("adapt", "aws"), 140 concurrencyStrategy: opt.ConcurrencyStrategy, 141 } 142 143 cfg, err := config.LoadDefaultConfig(ctx) 144 if err != nil { 145 return err 146 } 147 148 c.sessionCfg = cfg 149 150 if opt.Region != "" { 151 c.Debug("Using region '%s'", opt.Region) 152 c.sessionCfg.Region = opt.Region 153 } 154 if opt.Endpoint != "" { 155 c.Debug("Using endpoint '%s'", opt.Endpoint) 156 c.sessionCfg.EndpointResolverWithOptions = createResolver(opt.Endpoint) 157 } 158 159 c.Debug("Discovering caller identity...") 160 stsClient := sts.NewFromConfig(c.sessionCfg) 161 result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) 162 if err != nil { 163 return fmt.Errorf("failed to discover AWS caller identity: %w", err) 164 } 165 if result.Account == nil { 166 return fmt.Errorf("missing account id for aws account") 167 } 168 c.accountID = *result.Account 169 c.Debug("AWS account ID: %s", c.accountID) 170 171 if len(opt.Services) == 0 { 172 c.Debug("Preparing to run for all %d registered services...", len(registeredAdapters)) 173 opt.ProgressTracker.SetTotalServices(len(registeredAdapters)) 174 } else { 175 c.Debug("Preparing to run for %d filtered services...", len(opt.Services)) 176 opt.ProgressTracker.SetTotalServices(len(opt.Services)) 177 } 178 179 c.region = c.sessionCfg.Region 180 181 var adapterErrors []error 182 183 for _, adapter := range registeredAdapters { 184 if len(opt.Services) != 0 && !contains(opt.Services, adapter.Name()) { 185 continue 186 } 187 c.currentService = adapter.Name() 188 c.Debug("Running adapter for %s...", adapter.Name()) 189 opt.ProgressTracker.StartService(adapter.Name()) 190 191 if err := adapter.Adapt(c, state); err != nil { 192 c.Debug("Error occurred while running adapter for %s: %s", adapter.Name(), err) 193 adapterErrors = append(adapterErrors, fmt.Errorf("failed to run adapter for %s: %w", adapter.Name(), err)) 194 } 195 opt.ProgressTracker.FinishService() 196 } 197 198 if len(adapterErrors) > 0 { 199 return errs.NewAdapterError(adapterErrors) 200 } 201 202 return nil 203 } 204 205 func contains(services []string, service string) bool { 206 for _, s := range services { 207 if s == service { 208 return true 209 } 210 } 211 return false 212 }