github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/cloudfront/adapt.go (about)

     1  package cloudfront
     2  
import (
	"fmt"

	api "github.com/aws/aws-sdk-go-v2/service/cloudfront"
	"github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws"
	"github.com/khulnasoft-lab/defsec/pkg/concurrency"
	"github.com/khulnasoft-lab/defsec/pkg/providers/aws/cloudfront"
	"github.com/khulnasoft-lab/defsec/pkg/state"
	defsecTypes "github.com/khulnasoft-lab/defsec/pkg/types"
)
    12  
// adapter is the CloudFront service adapter. It reads distributions through
// the AWS SDK v2 CloudFront client and writes them into the defsec state.
type adapter struct {
	// RootAdapter is assigned in Adapt; its embedded methods (SessionConfig,
	// Context, Tracker, CreateMetadataFromARN) are used throughout this adapter.
	*aws.RootAdapter
	// client is the CloudFront API client, constructed in Adapt from the root
	// adapter's session config.
	client *api.Client
}
    17  
    18  func init() {
    19  	aws.RegisterServiceAdapter(&adapter{})
    20  }
    21  
    22  func (a *adapter) Provider() string {
    23  	return "aws"
    24  }
    25  
    26  func (a *adapter) Name() string {
    27  	return "cloudfront"
    28  }
    29  
    30  func (a *adapter) Adapt(root *aws.RootAdapter, state *state.State) error {
    31  
    32  	a.RootAdapter = root
    33  	a.client = api.NewFromConfig(root.SessionConfig())
    34  	var err error
    35  
    36  	state.AWS.Cloudfront.Distributions, err = a.getDistributions()
    37  	if err != nil {
    38  		return err
    39  	}
    40  
    41  	return nil
    42  }
    43  
    44  func (a *adapter) getDistributions() ([]cloudfront.Distribution, error) {
    45  
    46  	a.Tracker().SetServiceLabel("Discovering distributions...")
    47  
    48  	var apiDistributions []types.DistributionSummary
    49  	var input api.ListDistributionsInput
    50  	for {
    51  		output, err := a.client.ListDistributions(a.Context(), &input)
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  		apiDistributions = append(apiDistributions, output.DistributionList.Items...)
    56  		a.Tracker().SetTotalResources(len(apiDistributions))
    57  		if output.DistributionList.NextMarker == nil {
    58  			break
    59  		}
    60  		input.Marker = output.DistributionList.NextMarker
    61  	}
    62  
    63  	a.Tracker().SetServiceLabel("Adapting distributions...")
    64  	return concurrency.Adapt(apiDistributions, a.RootAdapter, a.adaptDistribution), nil
    65  }
    66  
    67  func (a *adapter) adaptDistribution(distribution types.DistributionSummary) (*cloudfront.Distribution, error) {
    68  
    69  	metadata := a.CreateMetadataFromARN(*distribution.ARN)
    70  
    71  	config, err := a.client.GetDistributionConfig(a.Context(), &api.GetDistributionConfigInput{
    72  		Id: distribution.Id,
    73  	})
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	var wafID string
    79  	if distribution.WebACLId != nil {
    80  		wafID = *distribution.WebACLId
    81  	}
    82  
    83  	var loggingBucket string
    84  	if config.DistributionConfig.Logging != nil && config.DistributionConfig.Logging.Bucket != nil {
    85  		loggingBucket = *config.DistributionConfig.Logging.Bucket
    86  	}
    87  
    88  	var defaultCacheBehaviour string
    89  	if config.DistributionConfig.DefaultCacheBehavior != nil {
    90  		defaultCacheBehaviour = string(config.DistributionConfig.DefaultCacheBehavior.ViewerProtocolPolicy)
    91  	}
    92  
    93  	var cacheBehaviours []cloudfront.CacheBehaviour
    94  	for _, cacheBehaviour := range config.DistributionConfig.CacheBehaviors.Items {
    95  		cacheBehaviours = append(cacheBehaviours, cloudfront.CacheBehaviour{
    96  			Metadata:             metadata,
    97  			ViewerProtocolPolicy: defsecTypes.String(string(cacheBehaviour.ViewerProtocolPolicy), metadata),
    98  		})
    99  	}
   100  
   101  	var minimumProtocolVersion string
   102  	if config.DistributionConfig.ViewerCertificate != nil {
   103  		minimumProtocolVersion = string(config.DistributionConfig.ViewerCertificate.MinimumProtocolVersion)
   104  	}
   105  
   106  	return &cloudfront.Distribution{
   107  		Metadata: metadata,
   108  		WAFID:    defsecTypes.String(wafID, metadata),
   109  		Logging: cloudfront.Logging{
   110  			Metadata: metadata,
   111  			Bucket:   defsecTypes.String(loggingBucket, metadata),
   112  		},
   113  		DefaultCacheBehaviour: cloudfront.CacheBehaviour{
   114  			Metadata:             metadata,
   115  			ViewerProtocolPolicy: defsecTypes.String(defaultCacheBehaviour, metadata),
   116  		},
   117  		OrdererCacheBehaviours: cacheBehaviours,
   118  		ViewerCertificate: cloudfront.ViewerCertificate{
   119  			Metadata:               metadata,
   120  			MinimumProtocolVersion: defsecTypes.String(minimumProtocolVersion, metadata),
   121  		},
   122  	}, nil
   123  }