github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/cloudfront/adapt.go (about) 1 package cloudfront 2 3 import ( 4 api "github.com/aws/aws-sdk-go-v2/service/cloudfront" 5 "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" 6 "github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws" 7 "github.com/khulnasoft-lab/defsec/pkg/concurrency" 8 "github.com/khulnasoft-lab/defsec/pkg/providers/aws/cloudfront" 9 "github.com/khulnasoft-lab/defsec/pkg/state" 10 defsecTypes "github.com/khulnasoft-lab/defsec/pkg/types" 11 ) 12 13 type adapter struct { 14 *aws.RootAdapter 15 client *api.Client 16 } 17 18 func init() { 19 aws.RegisterServiceAdapter(&adapter{}) 20 } 21 22 func (a *adapter) Provider() string { 23 return "aws" 24 } 25 26 func (a *adapter) Name() string { 27 return "cloudfront" 28 } 29 30 func (a *adapter) Adapt(root *aws.RootAdapter, state *state.State) error { 31 32 a.RootAdapter = root 33 a.client = api.NewFromConfig(root.SessionConfig()) 34 var err error 35 36 state.AWS.Cloudfront.Distributions, err = a.getDistributions() 37 if err != nil { 38 return err 39 } 40 41 return nil 42 } 43 44 func (a *adapter) getDistributions() ([]cloudfront.Distribution, error) { 45 46 a.Tracker().SetServiceLabel("Discovering distributions...") 47 48 var apiDistributions []types.DistributionSummary 49 var input api.ListDistributionsInput 50 for { 51 output, err := a.client.ListDistributions(a.Context(), &input) 52 if err != nil { 53 return nil, err 54 } 55 apiDistributions = append(apiDistributions, output.DistributionList.Items...) 56 a.Tracker().SetTotalResources(len(apiDistributions)) 57 if output.DistributionList.NextMarker == nil { 58 break 59 } 60 input.Marker = output.DistributionList.NextMarker 61 } 62 63 a.Tracker().SetServiceLabel("Adapting distributions...") 64 return concurrency.Adapt(apiDistributions, a.RootAdapter, a.adaptDistribution), nil 65 } 66 67 func (a *adapter) adaptDistribution(distribution types.DistributionSummary) (*cloudfront.Distribution, error) { 68 69 metadata := a.CreateMetadataFromARN(*distribution.ARN) 70 71 config, err := a.client.GetDistributionConfig(a.Context(), &api.GetDistributionConfigInput{ 72 Id: distribution.Id, 73 }) 74 if err != nil { 75 return nil, err 76 } 77 78 var wafID string 79 if distribution.WebACLId != nil { 80 wafID = *distribution.WebACLId 81 } 82 83 var loggingBucket string 84 if config.DistributionConfig.Logging != nil && config.DistributionConfig.Logging.Bucket != nil { 85 loggingBucket = *config.DistributionConfig.Logging.Bucket 86 } 87 88 var defaultCacheBehaviour string 89 if config.DistributionConfig.DefaultCacheBehavior != nil { 90 defaultCacheBehaviour = string(config.DistributionConfig.DefaultCacheBehavior.ViewerProtocolPolicy) 91 } 92 93 var cacheBehaviours []cloudfront.CacheBehaviour 94 for _, cacheBehaviour := range config.DistributionConfig.CacheBehaviors.Items { 95 cacheBehaviours = append(cacheBehaviours, cloudfront.CacheBehaviour{ 96 Metadata: metadata, 97 ViewerProtocolPolicy: defsecTypes.String(string(cacheBehaviour.ViewerProtocolPolicy), metadata), 98 }) 99 } 100 101 var minimumProtocolVersion string 102 if config.DistributionConfig.ViewerCertificate != nil { 103 minimumProtocolVersion = string(config.DistributionConfig.ViewerCertificate.MinimumProtocolVersion) 104 } 105 106 return &cloudfront.Distribution{ 107 Metadata: metadata, 108 WAFID: defsecTypes.String(wafID, metadata), 109 Logging: cloudfront.Logging{ 110 Metadata: metadata, 111 Bucket: defsecTypes.String(loggingBucket, metadata), 112 }, 113 DefaultCacheBehaviour: cloudfront.CacheBehaviour{ 114 Metadata: metadata, 115 ViewerProtocolPolicy: defsecTypes.String(defaultCacheBehaviour, metadata), 116 }, 117 OrdererCacheBehaviours: cacheBehaviours, 118 ViewerCertificate: cloudfront.ViewerCertificate{ 119 Metadata: metadata, 120 MinimumProtocolVersion: defsecTypes.String(minimumProtocolVersion, metadata), 121 }, 122 }, nil 123 }