github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/ec2/vpc_test.go (about)

     1  package ec2
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/aws/aws-sdk-go-v2/aws"
     7  	vpcApi "github.com/aws/aws-sdk-go-v2/service/ec2"
     8  	vpcTypes "github.com/aws/aws-sdk-go-v2/service/ec2/types"
     9  	aws2 "github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws"
    10  	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws/test"
    11  	"github.com/khulnasoft-lab/defsec/pkg/providers/aws/ec2"
    12  	"github.com/khulnasoft-lab/defsec/pkg/state"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  type rule struct {
    17  	egress     bool
    18  	protocol   string
    19  	ruleAction vpcTypes.RuleAction
    20  	cidrBlock  string
    21  	fromPort   int
    22  	toPort     int
    23  }
    24  
    25  type nacl struct {
    26  	naclRules []rule
    27  }
    28  
    29  type sg struct {
    30  	name        string
    31  	description string
    32  	sgRules     []rule
    33  }
    34  
    35  type vpcDetails struct {
    36  	nacl            *nacl
    37  	securityGroup   *sg
    38  	flowLogsEnabled bool
    39  }
    40  
    41  func Test_VPCNetworkACLs(t *testing.T) {
    42  
    43  	tests := []struct {
    44  		name    string
    45  		details vpcDetails
    46  	}{
    47  		{
    48  			name: "simple nacl",
    49  			details: vpcDetails{
    50  				nacl: &nacl{
    51  					naclRules: []rule{
    52  						{
    53  							egress:     true,
    54  							protocol:   "tcp",
    55  							cidrBlock:  "10.0.0.0/24",
    56  							ruleAction: vpcTypes.RuleActionDeny,
    57  							fromPort:   80,
    58  							toPort:     80,
    59  						},
    60  					},
    61  				},
    62  			},
    63  		},
    64  	}
    65  
    66  	ra, stack, err := test.CreateLocalstackAdapter(t)
    67  	defer func() { _ = stack.Stop() }()
    68  	require.NoError(t, err)
    69  
    70  	for _, tt := range tests {
    71  		t.Run(tt.name, func(t *testing.T) {
    72  			bootstrapVPC(t, ra, tt.details)
    73  
    74  			testState := &state.State{}
    75  			adapter := &adapter{}
    76  			err = adapter.Adapt(ra, testState)
    77  			require.NoError(t, err)
    78  
    79  			require.NotNil(t, testState.AWS.EC2)
    80  			require.Len(t, testState.AWS.EC2.NetworkACLs, 3)
    81  
    82  			var aclFound bool
    83  
    84  			for _, a := range testState.AWS.EC2.NetworkACLs {
    85  				if !a.IsDefaultRule.Value() {
    86  					aclFound = true
    87  					break
    88  				}
    89  			}
    90  
    91  			require.True(t, aclFound)
    92  
    93  		})
    94  	}
    95  }
    96  
    97  func Test_VPCFlowLogs(t *testing.T) {
    98  
    99  	tests := []struct {
   100  		name    string
   101  		details vpcDetails
   102  	}{
   103  		{
   104  			name: "simple flow logs",
   105  			details: vpcDetails{
   106  				flowLogsEnabled: true,
   107  			},
   108  		},
   109  	}
   110  
   111  	ra, stack, err := test.CreateLocalstackAdapter(t)
   112  	defer func() { _ = stack.Stop() }()
   113  	require.NoError(t, err)
   114  
   115  	for _, tt := range tests {
   116  		t.Run(tt.name, func(t *testing.T) {
   117  			vpcId := bootstrapVPC(t, ra, tt.details)
   118  
   119  			testState := &state.State{}
   120  			adapter := &adapter{}
   121  			err = adapter.Adapt(ra, testState)
   122  			require.NoError(t, err)
   123  
   124  			require.NotNil(t, testState.AWS.EC2)
   125  			var testVPCs []ec2.VPC
   126  			for _, v := range testState.AWS.EC2.VPCs {
   127  				if v.IsDefault.IsFalse() {
   128  					testVPCs = append(testVPCs, v)
   129  				}
   130  			}
   131  
   132  			require.Len(t, testVPCs, 1)
   133  			vpc := testVPCs[0]
   134  			require.Equal(t, tt.details.flowLogsEnabled, vpc.FlowLogsEnabled.Value())
   135  
   136  			destroyVPC(t, ra, vpcId)
   137  
   138  		})
   139  	}
   140  }
   141  
   142  func Test_VPCSecurityGroups(t *testing.T) {
   143  
   144  	tests := []struct {
   145  		name    string
   146  		details vpcDetails
   147  	}{
   148  		{
   149  			name: "simple security group",
   150  			details: vpcDetails{
   151  				securityGroup: &sg{
   152  					name:        "test-sg",
   153  					description: "a test security group description",
   154  					sgRules: []rule{
   155  						{
   156  							egress:    true,
   157  							protocol:  "tcp",
   158  							cidrBlock: "10.0.0.0/24",
   159  							fromPort:  80,
   160  							toPort:    80,
   161  						},
   162  					},
   163  				},
   164  			},
   165  		},
   166  	}
   167  
   168  	ra, stack, err := test.CreateLocalstackAdapter(t)
   169  	defer func() { _ = stack.Stop() }()
   170  	require.NoError(t, err)
   171  
   172  	for _, tt := range tests {
   173  		t.Run(tt.name, func(t *testing.T) {
   174  			vpcId := bootstrapVPC(t, ra, tt.details)
   175  
   176  			testState := &state.State{}
   177  			adapter := &adapter{}
   178  			err = adapter.Adapt(ra, testState)
   179  			require.NoError(t, err)
   180  
   181  			require.NotNil(t, testState.AWS.EC2)
   182  			require.Len(t, testState.AWS.EC2.SecurityGroups, 3)
   183  
   184  			sg := testState.AWS.EC2.SecurityGroups[0]
   185  			require.NotNil(t, sg)
   186  
   187  			destroyVPC(t, ra, vpcId)
   188  
   189  		})
   190  	}
   191  }
   192  
   193  func destroyVPC(t *testing.T, ra *aws2.RootAdapter, id *string) {
   194  	api := vpcApi.NewFromConfig(ra.SessionConfig())
   195  
   196  	_, err := api.DeleteVpc(ra.Context(), &vpcApi.DeleteVpcInput{
   197  		VpcId: id,
   198  	})
   199  
   200  	require.NoError(t, err)
   201  }
   202  
   203  func bootstrapVPC(t *testing.T, ra *aws2.RootAdapter, spec vpcDetails) *string {
   204  
   205  	api := vpcApi.NewFromConfig(ra.SessionConfig())
   206  
   207  	vpc, err := api.CreateVpc(ra.Context(), &vpcApi.CreateVpcInput{
   208  		CidrBlock: aws.String("10.0.0.0/16"),
   209  	})
   210  
   211  	require.NoError(t, err)
   212  
   213  	if spec.nacl != nil {
   214  		addNacl(t, ra, spec, api, vpc)
   215  	}
   216  
   217  	if spec.securityGroup != nil {
   218  		addSecurityGroup(t, ra, spec, api, vpc)
   219  	}
   220  
   221  	if spec.flowLogsEnabled {
   222  		addFlowLogs(t, ra, api, vpc)
   223  	}
   224  
   225  	return vpc.Vpc.VpcId
   226  }
   227  
   228  func addFlowLogs(t *testing.T, ra *aws2.RootAdapter, api *vpcApi.Client, vpc *vpcApi.CreateVpcOutput) {
   229  	logs, err := api.CreateFlowLogs(ra.Context(), &vpcApi.CreateFlowLogsInput{
   230  		ResourceIds:        []string{*vpc.Vpc.VpcId},
   231  		ResourceType:       vpcTypes.FlowLogsResourceTypeVpc,
   232  		LogDestinationType: vpcTypes.LogDestinationTypeS3,
   233  		LogDestination:     aws.String("arn:aws:s3:::access-logs"),
   234  	})
   235  
   236  	require.NoError(t, err)
   237  	require.NotNil(t, logs)
   238  }
   239  
   240  func addNacl(t *testing.T, ra *aws2.RootAdapter, spec vpcDetails, api *vpcApi.Client, vpc *vpcApi.CreateVpcOutput) {
   241  	acl, err := api.CreateNetworkAcl(ra.Context(), &vpcApi.CreateNetworkAclInput{
   242  		VpcId: vpc.Vpc.VpcId,
   243  	})
   244  	require.NoError(t, err)
   245  
   246  	for i, rule := range spec.nacl.naclRules {
   247  		_, err = api.CreateNetworkAclEntry(ra.Context(), &vpcApi.CreateNetworkAclEntryInput{
   248  			NetworkAclId: acl.NetworkAcl.NetworkAclId,
   249  			Egress:       aws.Bool(rule.egress),
   250  			RuleAction:   rule.ruleAction,
   251  			RuleNumber:   aws.Int32(int32(i)),
   252  			Protocol:     aws.String(rule.protocol),
   253  			CidrBlock:    aws.String(rule.cidrBlock),
   254  			PortRange: &vpcTypes.PortRange{
   255  				From: aws.Int32(80),
   256  				To:   aws.Int32(80),
   257  			},
   258  		})
   259  		require.NoError(t, err)
   260  	}
   261  }
   262  
   263  func addSecurityGroup(t *testing.T, ra *aws2.RootAdapter, spec vpcDetails, api *vpcApi.Client, vpc *vpcApi.CreateVpcOutput) {
   264  	_, err := api.CreateSecurityGroup(ra.Context(), &vpcApi.CreateSecurityGroupInput{
   265  		VpcId:       vpc.Vpc.VpcId,
   266  		GroupName:   aws.String(spec.securityGroup.name),
   267  		Description: aws.String(spec.securityGroup.description),
   268  	})
   269  	require.NoError(t, err)
   270  }