github.com/mponton/terratest@v0.44.0/modules/aws/vpc.go (about)

     1  package aws
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/aws/aws-sdk-go/aws"
     9  	"github.com/aws/aws-sdk-go/service/ec2"
    10  	"github.com/mponton/terratest/modules/random"
    11  	"github.com/mponton/terratest/modules/testing"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  // Vpc is an Amazon Virtual Private Cloud.
    16  type Vpc struct {
    17  	Id      string            // The ID of the VPC
    18  	Name    string            // The name of the VPC
    19  	Subnets []Subnet          // A list of subnets in the VPC
    20  	Tags    map[string]string // The tags associated with the VPC
    21  }
    22  
    23  // Subnet is a subnet in an availability zone.
    24  type Subnet struct {
    25  	Id               string            // The ID of the Subnet
    26  	AvailabilityZone string            // The Availability Zone the subnet is in
    27  	DefaultForAz     bool              // If the subnet is default for the Availability Zone
    28  	Tags             map[string]string // The tags associated with the subnet
    29  }
    30  
    31  const vpcIDFilterName = "vpc-id"
    32  const defaultForAzFilterName = "default-for-az"
    33  const resourceTypeFilterName = "resource-id"
    34  const resourceIdFilterName = "resource-type"
    35  const vpcResourceTypeFilterValue = "vpc"
    36  const subnetResourceTypeFilterValue = "subnet"
    37  const isDefaultFilterName = "isDefault"
    38  const isDefaultFilterValue = "true"
    39  const defaultVPCName = "Default"
    40  
    41  // GetDefaultVpc fetches information about the default VPC in the given region.
    42  func GetDefaultVpc(t testing.TestingT, region string) *Vpc {
    43  	vpc, err := GetDefaultVpcE(t, region)
    44  	require.NoError(t, err)
    45  	return vpc
    46  }
    47  
    48  // GetDefaultVpcE fetches information about the default VPC in the given region.
    49  func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) {
    50  	defaultVpcFilter := ec2.Filter{Name: aws.String(isDefaultFilterName), Values: []*string{aws.String(isDefaultFilterValue)}}
    51  	vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region)
    52  
    53  	numVpcs := len(vpcs)
    54  	if numVpcs != 1 {
    55  		return nil, fmt.Errorf("Expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs))
    56  	}
    57  
    58  	return vpcs[0], err
    59  }
    60  
    61  // GetVpcById fetches information about a VPC with given Id in the given region.
    62  func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc {
    63  	vpc, err := GetVpcByIdE(t, vpcId, region)
    64  	require.NoError(t, err)
    65  	return vpc
    66  }
    67  
    68  // GetVpcByIdE fetches information about a VPC with given Id in the given region.
    69  func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) {
    70  	vpcIdFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcId}}
    71  	vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region)
    72  
    73  	numVpcs := len(vpcs)
    74  	if numVpcs != 1 {
    75  		return nil, fmt.Errorf("Expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs))
    76  	}
    77  
    78  	return vpcs[0], err
    79  }
    80  
    81  // GetVpcsE fetches informations about VPCs from given regions limited by filters
    82  func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, error) {
    83  	client, err := NewEc2ClientE(t, region)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	vpcs, err := client.DescribeVpcs(&ec2.DescribeVpcsInput{Filters: filters})
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	numVpcs := len(vpcs.Vpcs)
    94  	retVal := make([]*Vpc, numVpcs)
    95  
    96  	for i, vpc := range vpcs.Vpcs {
    97  		vpcIdFilter := generateVpcIdFilter(aws.StringValue(vpc.VpcId))
    98  		subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIdFilter})
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  
   103  		tags, err := GetTagsForVpcE(t, aws.StringValue(vpc.VpcId), region)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  
   108  		retVal[i] = &Vpc{Id: aws.StringValue(vpc.VpcId), Name: FindVpcName(vpc), Subnets: subnets, Tags: tags}
   109  	}
   110  
   111  	return retVal, nil
   112  }
   113  
   114  // FindVpcName extracts the VPC name from its tags (if any). Fall back to "Default" if it's the default VPC or empty string
   115  // otherwise.
   116  func FindVpcName(vpc *ec2.Vpc) string {
   117  	for _, tag := range vpc.Tags {
   118  		if *tag.Key == "Name" {
   119  			return *tag.Value
   120  		}
   121  	}
   122  
   123  	if *vpc.IsDefault {
   124  		return defaultVPCName
   125  	}
   126  
   127  	return ""
   128  }
   129  
   130  // GetSubnetsForVpc gets the subnets in the specified VPC.
   131  func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet {
   132  	vpcIDFilter := generateVpcIdFilter(vpcID)
   133  	subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter})
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	return subnets
   138  }
   139  
   140  // GetAzDefaultSubnetsForVpc gets the default az subnets in the specified VPC.
   141  func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet {
   142  	vpcIDFilter := generateVpcIdFilter(vpcID)
   143  	defaultForAzFilter := ec2.Filter{
   144  		Name:   aws.String(defaultForAzFilterName),
   145  		Values: []*string{aws.String("true")},
   146  	}
   147  	subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter, &defaultForAzFilter})
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  	return subnets
   152  }
   153  
   154  // generateVpcIdFilter is a helper method to generate vpc id filter
   155  func generateVpcIdFilter(vpcID string) ec2.Filter {
   156  	return ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}}
   157  }
   158  
   159  // GetSubnetsForVpcE gets the subnets in the specified VPC.
   160  func GetSubnetsForVpcE(t testing.TestingT, region string, filters []*ec2.Filter) ([]Subnet, error) {
   161  	client, err := NewEc2ClientE(t, region)
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: filters})
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	subnets := []Subnet{}
   172  
   173  	for _, ec2Subnet := range subnetOutput.Subnets {
   174  		subnetTags := GetTagsForSubnet(t, *ec2Subnet.SubnetId, region)
   175  		subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone), DefaultForAz: aws.BoolValue(ec2Subnet.DefaultForAz), Tags: subnetTags}
   176  		subnets = append(subnets, subnet)
   177  	}
   178  
   179  	return subnets, nil
   180  }
   181  
   182  // GetTagsForVpc gets the tags for the specified VPC.
   183  func GetTagsForVpc(t testing.TestingT, vpcID string, region string) map[string]string {
   184  	tags, err := GetTagsForVpcE(t, vpcID, region)
   185  	require.NoError(t, err)
   186  
   187  	return tags
   188  }
   189  
   190  // GetTagsForVpcE gets the tags for the specified VPC.
   191  func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string]string, error) {
   192  	client, err := NewEc2ClientE(t, region)
   193  	require.NoError(t, err)
   194  
   195  	vpcResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(vpcResourceTypeFilterValue)}}
   196  	vpcResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&vpcID}}
   197  	tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&vpcResourceTypeFilter, &vpcResourceIdFilter}})
   198  	require.NoError(t, err)
   199  
   200  	tags := map[string]string{}
   201  	for _, tag := range tagsOutput.Tags {
   202  		tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value)
   203  	}
   204  
   205  	return tags, nil
   206  }
   207  
   208  // GetDefaultSubnetIDsForVpc gets the ids of the subnets that are the default subnet for the AvailabilityZone
   209  func GetDefaultSubnetIDsForVpc(t testing.TestingT, vpc Vpc) []string {
   210  	subnetIDs, err := GetDefaultSubnetIDsForVpcE(t, vpc)
   211  	require.NoError(t, err)
   212  	return subnetIDs
   213  }
   214  
   215  // GetDefaultSubnetIDsForVpcE gets the ids of the subnets that are the default subnet for the AvailabilityZone
   216  func GetDefaultSubnetIDsForVpcE(t testing.TestingT, vpc Vpc) ([]string, error) {
   217  	if vpc.Name != defaultVPCName {
   218  		// You cannot create a default subnet in a nondefault VPC
   219  		// https://docs.aws.amazon.com/vpc/latest/userguide/default-vpc.html
   220  		return nil, fmt.Errorf("Only default VPCs have default subnets but VPC with id %s is not default VPC", vpc.Id)
   221  	}
   222  	subnetIDs := []string{}
   223  	numSubnets := len(vpc.Subnets)
   224  	if numSubnets == 0 {
   225  		return nil, fmt.Errorf("Expected to find at least one subnet in vpc with ID %s but found zero", vpc.Id)
   226  	}
   227  
   228  	for _, subnet := range vpc.Subnets {
   229  		if subnet.DefaultForAz {
   230  			subnetIDs = append(subnetIDs, subnet.Id)
   231  		}
   232  	}
   233  	return subnetIDs, nil
   234  }
   235  
   236  // GetTagsForSubnet gets the tags for the specified subnet.
   237  func GetTagsForSubnet(t testing.TestingT, subnetId string, region string) map[string]string {
   238  	tags, err := GetTagsForSubnetE(t, subnetId, region)
   239  	require.NoError(t, err)
   240  
   241  	return tags
   242  }
   243  
   244  // GetTagsForSubnetE gets the tags for the specified subnet.
   245  func GetTagsForSubnetE(t testing.TestingT, subnetId string, region string) (map[string]string, error) {
   246  	client, err := NewEc2ClientE(t, region)
   247  	require.NoError(t, err)
   248  
   249  	subnetResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(subnetResourceTypeFilterValue)}}
   250  	subnetResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&subnetId}}
   251  	tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&subnetResourceTypeFilter, &subnetResourceIdFilter}})
   252  	require.NoError(t, err)
   253  
   254  	tags := map[string]string{}
   255  	for _, tag := range tagsOutput.Tags {
   256  		tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value)
   257  	}
   258  
   259  	return tags, nil
   260  }
   261  
   262  // IsPublicSubnet returns True if the subnet identified by the given id in the provided region is public.
   263  func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool {
   264  	isPublic, err := IsPublicSubnetE(t, subnetId, region)
   265  	require.NoError(t, err)
   266  	return isPublic
   267  }
   268  
   269  // IsPublicSubnetE returns True if the subnet identified by the given id in the provided region is public.
   270  func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, error) {
   271  	subnetIdFilterName := "association.subnet-id"
   272  
   273  	subnetIdFilter := ec2.Filter{
   274  		Name:   &subnetIdFilterName,
   275  		Values: []*string{&subnetId},
   276  	}
   277  
   278  	client, err := NewEc2ClientE(t, region)
   279  	if err != nil {
   280  		return false, err
   281  	}
   282  
   283  	rts, err := client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&subnetIdFilter}})
   284  	if err != nil {
   285  		return false, err
   286  	}
   287  
   288  	if len(rts.RouteTables) == 0 {
   289  		// Subnets not explicitly associated with any route table are implicitly associated with the main route table
   290  		rts, err = getImplicitRouteTableForSubnetE(t, subnetId, region)
   291  		if err != nil {
   292  			return false, err
   293  		}
   294  	}
   295  
   296  	for _, rt := range rts.RouteTables {
   297  		for _, r := range rt.Routes {
   298  			if strings.HasPrefix(aws.StringValue(r.GatewayId), "igw-") {
   299  				return true, nil
   300  			}
   301  		}
   302  	}
   303  
   304  	return false, nil
   305  }
   306  
   307  func getImplicitRouteTableForSubnetE(t testing.TestingT, subnetId string, region string) (*ec2.DescribeRouteTablesOutput, error) {
   308  	mainRouteFilterName := "association.main"
   309  	mainRouteFilterValue := "true"
   310  	subnetFilterName := "subnet-id"
   311  
   312  	client, err := NewEc2ClientE(t, region)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	subnetFilter := ec2.Filter{
   318  		Name:   &subnetFilterName,
   319  		Values: []*string{&subnetId},
   320  	}
   321  	subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&subnetFilter}})
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  	numSubnets := len(subnetOutput.Subnets)
   326  	if numSubnets != 1 {
   327  		return nil, fmt.Errorf("Expected to find one subnet with id %s but found %s", subnetId, strconv.Itoa(numSubnets))
   328  	}
   329  
   330  	mainRouteFilter := ec2.Filter{
   331  		Name:   &mainRouteFilterName,
   332  		Values: []*string{&mainRouteFilterValue},
   333  	}
   334  	vpcFilter := ec2.Filter{
   335  		Name:   aws.String(vpcIDFilterName),
   336  		Values: []*string{subnetOutput.Subnets[0].VpcId},
   337  	}
   338  	return client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&mainRouteFilter, &vpcFilter}})
   339  }
   340  
   341  // GetRandomPrivateCidrBlock gets a random CIDR block from the range of acceptable private IP addresses per RFC 1918
   342  // (https://tools.ietf.org/html/rfc1918#section-3)
   343  // The routingPrefix refers to the "/28" in 1.2.3.4/28.
   344  // Note that, as written, this function will return a subset of all valid ranges. Since we will probably use this function
   345  // mostly for generating random CIDR ranges for VPCs and Subnets, having comprehensive set coverage is not essential.
   346  func GetRandomPrivateCidrBlock(routingPrefix int) string {
   347  
   348  	var o1, o2, o3, o4 int
   349  
   350  	switch routingPrefix {
   351  	case 32:
   352  		o1 = random.RandomInt([]int{10, 172, 192})
   353  
   354  		switch o1 {
   355  		case 10:
   356  			o2 = random.Random(0, 255)
   357  			o3 = random.Random(0, 255)
   358  			o4 = random.Random(0, 255)
   359  		case 172:
   360  			o2 = random.Random(16, 31)
   361  			o3 = random.Random(0, 255)
   362  			o4 = random.Random(0, 255)
   363  		case 192:
   364  			o2 = 168
   365  			o3 = random.Random(0, 255)
   366  			o4 = random.Random(0, 255)
   367  		}
   368  
   369  	case 31, 30, 29, 28, 27, 26, 25:
   370  		fallthrough
   371  	case 24:
   372  		o1 = random.RandomInt([]int{10, 172, 192})
   373  
   374  		switch o1 {
   375  		case 10:
   376  			o2 = random.Random(0, 255)
   377  			o3 = random.Random(0, 255)
   378  			o4 = 0
   379  		case 172:
   380  			o2 = 16
   381  			o3 = 0
   382  			o4 = 0
   383  		case 192:
   384  			o2 = 168
   385  			o3 = 0
   386  			o4 = 0
   387  		}
   388  	case 23, 22, 21, 20, 19:
   389  		fallthrough
   390  	case 18:
   391  		o1 = random.RandomInt([]int{10, 172, 192})
   392  
   393  		switch o1 {
   394  		case 10:
   395  			o2 = 0
   396  			o3 = 0
   397  			o4 = 0
   398  		case 172:
   399  			o2 = 16
   400  			o3 = 0
   401  			o4 = 0
   402  		case 192:
   403  			o2 = 168
   404  			o3 = 0
   405  			o4 = 0
   406  		}
   407  	}
   408  	return fmt.Sprintf("%d.%d.%d.%d/%d", o1, o2, o3, o4, routingPrefix)
   409  }
   410  
   411  // GetFirstTwoOctets gets the first two octets from a CIDR block.
   412  func GetFirstTwoOctets(cidrBlock string) string {
   413  	ipAddr := strings.Split(cidrBlock, "/")[0]
   414  	octets := strings.Split(ipAddr, ".")
   415  	return octets[0] + "." + octets[1]
   416  }