github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/cloudtrail/adapt.go (about)

     1  package cloudtrail
     2  
import (
	"fmt"

	api "github.com/aws/aws-sdk-go-v2/service/cloudtrail"
	"github.com/aws/aws-sdk-go-v2/service/cloudtrail/types"
	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws"
	"github.com/khulnasoft-lab/defsec/pkg/concurrency"
	"github.com/khulnasoft-lab/defsec/pkg/providers/aws/cloudtrail"
	"github.com/khulnasoft-lab/defsec/pkg/state"
	defsecTypes "github.com/khulnasoft-lab/defsec/pkg/types"
)
    12  
    13  type adapter struct {
    14  	*aws.RootAdapter
    15  	client *api.Client
    16  }
    17  
    18  func init() {
    19  	aws.RegisterServiceAdapter(&adapter{})
    20  }
    21  
    22  func (a *adapter) Provider() string {
    23  	return "aws"
    24  }
    25  
    26  func (a *adapter) Name() string {
    27  	return "cloudtrail"
    28  }
    29  
    30  func (a *adapter) Adapt(root *aws.RootAdapter, state *state.State) error {
    31  
    32  	a.RootAdapter = root
    33  	a.client = api.NewFromConfig(root.SessionConfig())
    34  	var err error
    35  
    36  	state.AWS.CloudTrail.Trails, err = a.getTrails()
    37  	if err != nil {
    38  		return err
    39  	}
    40  
    41  	return nil
    42  }
    43  
    44  func (a *adapter) getTrails() ([]cloudtrail.Trail, error) {
    45  
    46  	a.Tracker().SetServiceLabel("Discovering trails...")
    47  
    48  	var apiTrails []types.TrailInfo
    49  	var input api.ListTrailsInput
    50  	for {
    51  		output, err := a.client.ListTrails(a.Context(), &input)
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  		apiTrails = append(apiTrails, output.Trails...)
    56  		a.Tracker().SetTotalResources(len(apiTrails))
    57  		if output.NextToken == nil {
    58  			break
    59  		}
    60  		input.NextToken = output.NextToken
    61  	}
    62  
    63  	a.Tracker().SetServiceLabel("Adapting trails...")
    64  	return concurrency.Adapt(apiTrails, a.RootAdapter, a.adaptTrail), nil
    65  }
    66  
    67  func (a *adapter) adaptTrail(info types.TrailInfo) (*cloudtrail.Trail, error) {
    68  
    69  	metadata := a.CreateMetadataFromARN(*info.TrailARN)
    70  
    71  	response, err := a.client.GetTrail(a.Context(), &api.GetTrailInput{
    72  		Name: info.TrailARN,
    73  	})
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	var kmsKeyId string
    79  	if response.Trail.KmsKeyId != nil {
    80  		kmsKeyId = *response.Trail.KmsKeyId
    81  	}
    82  
    83  	status, err := a.client.GetTrailStatus(a.Context(), &api.GetTrailStatusInput{
    84  		Name: response.Trail.Name,
    85  	})
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	cloudWatchLogsArn := defsecTypes.StringDefault("", metadata)
    91  	if response.Trail.CloudWatchLogsLogGroupArn != nil {
    92  		cloudWatchLogsArn = defsecTypes.String(*response.Trail.CloudWatchLogsLogGroupArn, metadata)
    93  	}
    94  
    95  	var bucketName string
    96  	if response.Trail.S3BucketName != nil {
    97  		bucketName = *response.Trail.S3BucketName
    98  	}
    99  
   100  	name := defsecTypes.StringDefault("", metadata)
   101  	if info.Name != nil {
   102  		name = defsecTypes.String(*info.Name, metadata)
   103  	}
   104  
   105  	isLogging := defsecTypes.BoolDefault(false, metadata)
   106  	if status.IsLogging != nil {
   107  		isLogging = defsecTypes.Bool(*status.IsLogging, metadata)
   108  	}
   109  
   110  	var eventSelectors []cloudtrail.EventSelector
   111  	if response.Trail.HasCustomEventSelectors != nil && *response.Trail.HasCustomEventSelectors {
   112  		output, err := a.client.GetEventSelectors(a.Context(), &api.GetEventSelectorsInput{
   113  			TrailName: info.Name,
   114  		})
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  		for _, eventSelector := range output.EventSelectors {
   119  			var resources []cloudtrail.DataResource
   120  			for _, dataResource := range eventSelector.DataResources {
   121  				typ := defsecTypes.StringDefault("", metadata)
   122  				if dataResource.Type != nil {
   123  					typ = defsecTypes.String(*dataResource.Type, metadata)
   124  				}
   125  				var values defsecTypes.StringValueList
   126  				for _, value := range dataResource.Values {
   127  					values = append(values, defsecTypes.String(value, metadata))
   128  				}
   129  				resources = append(resources, cloudtrail.DataResource{
   130  					Metadata: metadata,
   131  					Type:     typ,
   132  					Values:   values,
   133  				})
   134  			}
   135  			eventSelectors = append(eventSelectors, cloudtrail.EventSelector{
   136  				Metadata:      metadata,
   137  				DataResources: resources,
   138  				ReadWriteType: defsecTypes.String(string(eventSelector.ReadWriteType), metadata),
   139  			})
   140  		}
   141  	}
   142  
   143  	return &cloudtrail.Trail{
   144  		Metadata:                  metadata,
   145  		Name:                      name,
   146  		EnableLogFileValidation:   defsecTypes.Bool(response.Trail.LogFileValidationEnabled != nil && *response.Trail.LogFileValidationEnabled, metadata),
   147  		IsMultiRegion:             defsecTypes.Bool(response.Trail.IsMultiRegionTrail != nil && *response.Trail.IsMultiRegionTrail, metadata),
   148  		CloudWatchLogsLogGroupArn: cloudWatchLogsArn,
   149  		KMSKeyID:                  defsecTypes.String(kmsKeyId, metadata),
   150  		IsLogging:                 isLogging,
   151  		BucketName:                defsecTypes.String(bucketName, metadata),
   152  		EventSelectors:            eventSelectors,
   153  	}, nil
   154  }