github.com/darmach/terratest@v0.34.8-0.20210517103231-80931f95e3ff/modules/aws/vpc.go (about)

     1  package aws
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/aws/aws-sdk-go/aws"
     9  	"github.com/aws/aws-sdk-go/service/ec2"
    10  	"github.com/gruntwork-io/terratest/modules/random"
    11  	"github.com/gruntwork-io/terratest/modules/testing"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  // Vpc is an Amazon Virtual Private Cloud.
    16  type Vpc struct {
    17  	Id      string   // The ID of the VPC
    18  	Name    string   // The name of the VPC
    19  	Subnets []Subnet // A list of subnets in the VPC
    20  }
    21  
    22  // Subnet is a subnet in an availability zone.
    23  type Subnet struct {
    24  	Id               string // The ID of the Subnet
    25  	AvailabilityZone string // The Availability Zone the subnet is in
    26  }
    27  
    28  var vpcIDFilterName = "vpc-id"
    29  var isDefaultFilterName = "isDefault"
    30  var isDefaultFilterValue = "true"
    31  
    32  // GetDefaultVpc fetches information about the default VPC in the given region.
    33  func GetDefaultVpc(t testing.TestingT, region string) *Vpc {
    34  	vpc, err := GetDefaultVpcE(t, region)
    35  	require.NoError(t, err)
    36  	return vpc
    37  }
    38  
    39  // GetDefaultVpcE fetches information about the default VPC in the given region.
    40  func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) {
    41  	defaultVpcFilter := ec2.Filter{Name: &isDefaultFilterName, Values: []*string{&isDefaultFilterValue}}
    42  	vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region)
    43  
    44  	numVpcs := len(vpcs)
    45  	if numVpcs != 1 {
    46  		return nil, fmt.Errorf("Expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs))
    47  	}
    48  
    49  	return vpcs[0], err
    50  }
    51  
    52  // GetVpcById fetches information about a VPC with given Id in the given region.
    53  func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc {
    54  	vpc, err := GetVpcByIdE(t, vpcId, region)
    55  	require.NoError(t, err)
    56  	return vpc
    57  }
    58  
    59  // GetVpcByIdE fetches information about a VPC with given Id in the given region.
    60  func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) {
    61  	vpcIdFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}}
    62  	vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region)
    63  
    64  	numVpcs := len(vpcs)
    65  	if numVpcs != 1 {
    66  		return nil, fmt.Errorf("Expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs))
    67  	}
    68  
    69  	return vpcs[0], err
    70  }
    71  
    72  // GetVpcsE fetches informations about VPCs from given regions limited by filters
    73  func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, error) {
    74  	client, err := NewEc2ClientE(t, region)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	vpcs, err := client.DescribeVpcs(&ec2.DescribeVpcsInput{Filters: filters})
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	numVpcs := len(vpcs.Vpcs)
    85  	retVal := make([]*Vpc, numVpcs)
    86  
    87  	for i, vpc := range vpcs.Vpcs {
    88  		subnets, err := GetSubnetsForVpcE(t, aws.StringValue(vpc.VpcId), region)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		retVal[i] = &Vpc{Id: aws.StringValue(vpc.VpcId), Name: FindVpcName(vpc), Subnets: subnets}
    93  	}
    94  
    95  	return retVal, nil
    96  }
    97  
    98  // FindVpcName extracts the VPC name from its tags (if any). Fall back to "Default" if it's the default VPC or empty string
    99  // otherwise.
   100  func FindVpcName(vpc *ec2.Vpc) string {
   101  	for _, tag := range vpc.Tags {
   102  		if *tag.Key == "Name" {
   103  			return *tag.Value
   104  		}
   105  	}
   106  
   107  	if *vpc.IsDefault {
   108  		return "Default"
   109  	}
   110  
   111  	return ""
   112  }
   113  
   114  // GetSubnetsForVpc gets the subnets in the specified VPC.
   115  func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet {
   116  	subnets, err := GetSubnetsForVpcE(t, vpcID, region)
   117  	if err != nil {
   118  		t.Fatal(err)
   119  	}
   120  	return subnets
   121  }
   122  
   123  // GetSubnetsForVpcE gets the subnets in the specified VPC.
   124  func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subnet, error) {
   125  	client, err := NewEc2ClientE(t, region)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcID}}
   131  	subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}})
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	subnets := []Subnet{}
   137  
   138  	for _, ec2Subnet := range subnetOutput.Subnets {
   139  		subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone)}
   140  		subnets = append(subnets, subnet)
   141  	}
   142  
   143  	return subnets, nil
   144  }
   145  
   146  // IsPublicSubnet returns True if the subnet identified by the given id in the provided region is public.
   147  func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool {
   148  	isPublic, err := IsPublicSubnetE(t, subnetId, region)
   149  	require.NoError(t, err)
   150  	return isPublic
   151  }
   152  
   153  // IsPublicSubnetE returns True if the subnet identified by the given id in the provided region is public.
   154  func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, error) {
   155  	subnetIdFilterName := "association.subnet-id"
   156  
   157  	subnetIdFilter := ec2.Filter{
   158  		Name:   &subnetIdFilterName,
   159  		Values: []*string{&subnetId},
   160  	}
   161  
   162  	client, err := NewEc2ClientE(t, region)
   163  	if err != nil {
   164  		return false, err
   165  	}
   166  
   167  	rts, err := client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&subnetIdFilter}})
   168  	if err != nil {
   169  		return false, err
   170  	}
   171  
   172  	for _, rt := range rts.RouteTables {
   173  		for _, r := range rt.Routes {
   174  			if strings.HasPrefix(aws.StringValue(r.GatewayId), "igw-") {
   175  				return true, nil
   176  			}
   177  		}
   178  	}
   179  
   180  	return false, nil
   181  }
   182  
   183  // GetRandomPrivateCidrBlock gets a random CIDR block from the range of acceptable private IP addresses per RFC 1918
   184  // (https://tools.ietf.org/html/rfc1918#section-3)
   185  // The routingPrefix refers to the "/28" in 1.2.3.4/28.
   186  // Note that, as written, this function will return a subset of all valid ranges. Since we will probably use this function
   187  // mostly for generating random CIDR ranges for VPCs and Subnets, having comprehensive set coverage is not essential.
   188  func GetRandomPrivateCidrBlock(routingPrefix int) string {
   189  
   190  	var o1, o2, o3, o4 int
   191  
   192  	switch routingPrefix {
   193  	case 32:
   194  		o1 = random.RandomInt([]int{10, 172, 192})
   195  
   196  		switch o1 {
   197  		case 10:
   198  			o2 = random.Random(0, 255)
   199  			o3 = random.Random(0, 255)
   200  			o4 = random.Random(0, 255)
   201  		case 172:
   202  			o2 = random.Random(16, 31)
   203  			o3 = random.Random(0, 255)
   204  			o4 = random.Random(0, 255)
   205  		case 192:
   206  			o2 = 168
   207  			o3 = random.Random(0, 255)
   208  			o4 = random.Random(0, 255)
   209  		}
   210  
   211  	case 31, 30, 29, 28, 27, 26, 25:
   212  		fallthrough
   213  	case 24:
   214  		o1 = random.RandomInt([]int{10, 172, 192})
   215  
   216  		switch o1 {
   217  		case 10:
   218  			o2 = random.Random(0, 255)
   219  			o3 = random.Random(0, 255)
   220  			o4 = 0
   221  		case 172:
   222  			o2 = 16
   223  			o3 = 0
   224  			o4 = 0
   225  		case 192:
   226  			o2 = 168
   227  			o3 = 0
   228  			o4 = 0
   229  		}
   230  	case 23, 22, 21, 20, 19:
   231  		fallthrough
   232  	case 18:
   233  		o1 = random.RandomInt([]int{10, 172, 192})
   234  
   235  		switch o1 {
   236  		case 10:
   237  			o2 = 0
   238  			o3 = 0
   239  			o4 = 0
   240  		case 172:
   241  			o2 = 16
   242  			o3 = 0
   243  			o4 = 0
   244  		case 192:
   245  			o2 = 168
   246  			o3 = 0
   247  			o4 = 0
   248  		}
   249  	}
   250  	return fmt.Sprintf("%d.%d.%d.%d/%d", o1, o2, o3, o4, routingPrefix)
   251  }
   252  
   253  // GetFirstTwoOctets gets the first two octets from a CIDR block.
   254  func GetFirstTwoOctets(cidrBlock string) string {
   255  	ipAddr := strings.Split(cidrBlock, "/")[0]
   256  	octets := strings.Split(ipAddr, ".")
   257  	return octets[0] + "." + octets[1]
   258  }