github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/adapt.go (about)

     1  package aws
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/khulnasoft-lab/defsec/pkg/concurrency"
     8  	"github.com/khulnasoft-lab/defsec/pkg/errs"
     9  	"github.com/khulnasoft-lab/defsec/pkg/types"
    10  
    11  	"github.com/khulnasoft-lab/defsec/pkg/debug"
    12  
    13  	"github.com/aws/aws-sdk-go-v2/service/sts"
    14  
    15  	"github.com/aws/aws-sdk-go-v2/aws"
    16  	"github.com/aws/aws-sdk-go-v2/aws/arn"
    17  	"github.com/aws/aws-sdk-go-v2/config"
    18  	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/options"
    19  	"github.com/khulnasoft-lab/defsec/pkg/progress"
    20  	"github.com/khulnasoft-lab/defsec/pkg/state"
    21  )
    22  
    23  var registeredAdapters []ServiceAdapter
    24  
    25  func RegisterServiceAdapter(adapter ServiceAdapter) {
    26  	for _, existing := range registeredAdapters {
    27  		if existing.Name() == adapter.Name() {
    28  			panic(fmt.Sprintf("duplicate service adapter: %s", adapter.Name()))
    29  		}
    30  	}
    31  	registeredAdapters = append(registeredAdapters, adapter)
    32  }
    33  
    34  type ServiceAdapter interface {
    35  	Name() string
    36  	Provider() string
    37  	Adapt(root *RootAdapter, state *state.State) error
    38  }
    39  
    40  type RootAdapter struct {
    41  	ctx                 context.Context
    42  	sessionCfg          aws.Config
    43  	tracker             progress.ServiceTracker
    44  	accountID           string
    45  	currentService      string
    46  	region              string
    47  	debugWriter         debug.Logger
    48  	concurrencyStrategy concurrency.Strategy
    49  }
    50  
    51  func NewRootAdapter(ctx context.Context, cfg aws.Config, tracker progress.ServiceTracker) *RootAdapter {
    52  	return &RootAdapter{
    53  		ctx:        ctx,
    54  		tracker:    tracker,
    55  		sessionCfg: cfg,
    56  		region:     cfg.Region,
    57  	}
    58  }
    59  
    60  func (a *RootAdapter) Region() string {
    61  	return a.region
    62  }
    63  
    64  func (a *RootAdapter) Debug(format string, args ...interface{}) {
    65  	a.debugWriter.Log(format, args...)
    66  }
    67  
    68  func (a *RootAdapter) ConcurrencyStrategy() concurrency.Strategy {
    69  	return a.concurrencyStrategy
    70  }
    71  
    72  func (a *RootAdapter) SessionConfig() aws.Config {
    73  	return a.sessionCfg
    74  }
    75  
    76  func (a *RootAdapter) Context() context.Context {
    77  	return a.ctx
    78  }
    79  
    80  func (a *RootAdapter) Tracker() progress.ServiceTracker {
    81  	return a.tracker
    82  }
    83  
    84  func (a *RootAdapter) CreateMetadata(resource string) types.Metadata {
    85  
    86  	// some services don't require region/account id in the ARN
    87  	// see https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces
    88  	namespace := a.accountID
    89  	region := a.region
    90  	switch a.currentService {
    91  	case "s3":
    92  		namespace = ""
    93  		region = ""
    94  	}
    95  
    96  	return a.CreateMetadataFromARN((arn.ARN{
    97  		Partition: "aws",
    98  		Service:   a.currentService,
    99  		Region:    region,
   100  		AccountID: namespace,
   101  		Resource:  resource,
   102  	}).String())
   103  }
   104  
   105  func (a *RootAdapter) CreateMetadataFromARN(arn string) types.Metadata {
   106  	return types.NewRemoteMetadata(arn)
   107  }
   108  
   109  type resolver struct {
   110  	endpoint string
   111  }
   112  
   113  func (r *resolver) ResolveEndpoint(_, _ string, _ ...interface{}) (aws.Endpoint, error) {
   114  	return aws.Endpoint{
   115  		URL:           r.endpoint,
   116  		SigningRegion: "custom-signing-region",
   117  		Source:        aws.EndpointSourceCustom,
   118  	}, nil
   119  }
   120  
   121  func createResolver(endpoint string) aws.EndpointResolverWithOptions {
   122  	return &resolver{
   123  		endpoint: endpoint,
   124  	}
   125  }
   126  
   127  func AllServices() []string {
   128  	var services []string
   129  	for _, reg := range registeredAdapters {
   130  		services = append(services, reg.Name())
   131  	}
   132  	return services
   133  }
   134  
   135  func Adapt(ctx context.Context, state *state.State, opt options.Options) error {
   136  	c := &RootAdapter{
   137  		ctx:                 ctx,
   138  		tracker:             opt.ProgressTracker,
   139  		debugWriter:         opt.DebugWriter.Extend("adapt", "aws"),
   140  		concurrencyStrategy: opt.ConcurrencyStrategy,
   141  	}
   142  
   143  	cfg, err := config.LoadDefaultConfig(ctx)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	c.sessionCfg = cfg
   149  
   150  	if opt.Region != "" {
   151  		c.Debug("Using region '%s'", opt.Region)
   152  		c.sessionCfg.Region = opt.Region
   153  	}
   154  	if opt.Endpoint != "" {
   155  		c.Debug("Using endpoint '%s'", opt.Endpoint)
   156  		c.sessionCfg.EndpointResolverWithOptions = createResolver(opt.Endpoint)
   157  	}
   158  
   159  	c.Debug("Discovering caller identity...")
   160  	stsClient := sts.NewFromConfig(c.sessionCfg)
   161  	result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
   162  	if err != nil {
   163  		return fmt.Errorf("failed to discover AWS caller identity: %w", err)
   164  	}
   165  	if result.Account == nil {
   166  		return fmt.Errorf("missing account id for aws account")
   167  	}
   168  	c.accountID = *result.Account
   169  	c.Debug("AWS account ID: %s", c.accountID)
   170  
   171  	if len(opt.Services) == 0 {
   172  		c.Debug("Preparing to run for all %d registered services...", len(registeredAdapters))
   173  		opt.ProgressTracker.SetTotalServices(len(registeredAdapters))
   174  	} else {
   175  		c.Debug("Preparing to run for %d filtered services...", len(opt.Services))
   176  		opt.ProgressTracker.SetTotalServices(len(opt.Services))
   177  	}
   178  
   179  	c.region = c.sessionCfg.Region
   180  
   181  	var adapterErrors []error
   182  
   183  	for _, adapter := range registeredAdapters {
   184  		if len(opt.Services) != 0 && !contains(opt.Services, adapter.Name()) {
   185  			continue
   186  		}
   187  		c.currentService = adapter.Name()
   188  		c.Debug("Running adapter for %s...", adapter.Name())
   189  		opt.ProgressTracker.StartService(adapter.Name())
   190  
   191  		if err := adapter.Adapt(c, state); err != nil {
   192  			c.Debug("Error occurred while running adapter for %s: %s", adapter.Name(), err)
   193  			adapterErrors = append(adapterErrors, fmt.Errorf("failed to run adapter for %s: %w", adapter.Name(), err))
   194  		}
   195  		opt.ProgressTracker.FinishService()
   196  	}
   197  
   198  	if len(adapterErrors) > 0 {
   199  		return errs.NewAdapterError(adapterErrors)
   200  	}
   201  
   202  	return nil
   203  }
   204  
   205  func contains(services []string, service string) bool {
   206  	for _, s := range services {
   207  		if s == service {
   208  			return true
   209  		}
   210  	}
   211  	return false
   212  }