github.com/coreos/mantle@v0.13.0/platform/api/aws/network.go (about)

     1  // Copyright 2018 CoreOS, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aws
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/aws/aws-sdk-go/aws"
    21  	"github.com/aws/aws-sdk-go/aws/awserr"
    22  	"github.com/aws/aws-sdk-go/service/ec2"
    23  )
    24  
    25  // getSecurityGroupID gets a security group matching the given name.
    26  // If the security group does not exist, it's created.
    27  func (a *API) getSecurityGroupID(name string) (string, error) {
    28  	// using a Filter on group-name rather than the explicit GroupNames parameter
    29  	// disentangles this call from checking only inside of the default VPC
    30  	sgIds, err := a.ec2.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
    31  		Filters: []*ec2.Filter{
    32  			{
    33  				Name:   aws.String("group-name"),
    34  				Values: []*string{&name},
    35  			},
    36  		},
    37  	})
    38  
    39  	if len(sgIds.SecurityGroups) == 0 {
    40  		return a.createSecurityGroup(name)
    41  	}
    42  
    43  	if err != nil {
    44  		return "", fmt.Errorf("unable to get security group named %v: %v", name, err)
    45  	}
    46  
    47  	return *sgIds.SecurityGroups[0].GroupId, nil
    48  }
    49  
    50  // createSecurityGroup creates a security group with tcp/22 access allowed from the
    51  // internet.
    52  func (a *API) createSecurityGroup(name string) (string, error) {
    53  	vpcId, err := a.createVPC()
    54  	if err != nil {
    55  		return "", err
    56  	}
    57  	sg, err := a.ec2.CreateSecurityGroup(&ec2.CreateSecurityGroupInput{
    58  		GroupName:   aws.String(name),
    59  		Description: aws.String("mantle security group for testing"),
    60  		VpcId:       aws.String(vpcId),
    61  	})
    62  	if err != nil {
    63  		return "", err
    64  	}
    65  	plog.Debugf("created security group %v", *sg.GroupId)
    66  
    67  	allowedIngresses := []ec2.AuthorizeSecurityGroupIngressInput{
    68  		{
    69  			// SSH access from the public internet
    70  			// Full access from inside the same security group
    71  			GroupId: sg.GroupId,
    72  			IpPermissions: []*ec2.IpPermission{
    73  				{
    74  					IpProtocol: aws.String("tcp"),
    75  					IpRanges: []*ec2.IpRange{
    76  						{
    77  							CidrIp: aws.String("0.0.0.0/0"),
    78  						},
    79  					},
    80  					FromPort: aws.Int64(22),
    81  					ToPort:   aws.Int64(22),
    82  				},
    83  				{
    84  					IpProtocol: aws.String("tcp"),
    85  					FromPort:   aws.Int64(1),
    86  					ToPort:     aws.Int64(65535),
    87  					UserIdGroupPairs: []*ec2.UserIdGroupPair{
    88  						{
    89  							GroupId: sg.GroupId,
    90  							VpcId:   &vpcId,
    91  						},
    92  					},
    93  				},
    94  				{
    95  					IpProtocol: aws.String("udp"),
    96  					FromPort:   aws.Int64(1),
    97  					ToPort:     aws.Int64(65535),
    98  					UserIdGroupPairs: []*ec2.UserIdGroupPair{
    99  						{
   100  							GroupId: sg.GroupId,
   101  							VpcId:   &vpcId,
   102  						},
   103  					},
   104  				},
   105  				{
   106  					IpProtocol: aws.String("icmp"),
   107  					FromPort:   aws.Int64(-1),
   108  					ToPort:     aws.Int64(-1),
   109  					UserIdGroupPairs: []*ec2.UserIdGroupPair{
   110  						{
   111  							GroupId: sg.GroupId,
   112  							VpcId:   &vpcId,
   113  						},
   114  					},
   115  				},
   116  			},
   117  		},
   118  	}
   119  
   120  	for _, input := range allowedIngresses {
   121  		_, err := a.ec2.AuthorizeSecurityGroupIngress(&input)
   122  
   123  		if err != nil {
   124  			// We created the SG but can't add all the needed rules, let's try to
   125  			// bail gracefully
   126  			_, delErr := a.ec2.DeleteSecurityGroup(&ec2.DeleteSecurityGroupInput{
   127  				GroupId: sg.GroupId,
   128  			})
   129  			if delErr != nil {
   130  				return "", fmt.Errorf("created sg %v (%v) but couldn't authorize it. Manual deletion may be required: %v", *sg.GroupId, name, err)
   131  			}
   132  			return "", fmt.Errorf("created sg %v (%v), but couldn't authorize it and thus deleted it: %v", *sg.GroupId, name, err)
   133  		}
   134  	}
   135  	return *sg.GroupId, err
   136  }
   137  
   138  // createVPC creates a VPC with an IPV4 CidrBlock of 172.31.0.0/16
   139  func (a *API) createVPC() (string, error) {
   140  	vpc, err := a.ec2.CreateVpc(&ec2.CreateVpcInput{
   141  		CidrBlock: aws.String("172.31.0.0/16"),
   142  	})
   143  	if err != nil {
   144  		return "", fmt.Errorf("creating VPC: %v", err)
   145  	}
   146  	if vpc.Vpc == nil || vpc.Vpc.VpcId == nil {
   147  		return "", fmt.Errorf("vpc was nil after creation")
   148  	}
   149  	err = a.tagCreatedByMantle([]string{*vpc.Vpc.VpcId})
   150  	if err != nil {
   151  		return "", err
   152  	}
   153  
   154  	_, err = a.ec2.ModifyVpcAttribute(&ec2.ModifyVpcAttributeInput{
   155  		EnableDnsHostnames: &ec2.AttributeBooleanValue{
   156  			Value: aws.Bool(true),
   157  		},
   158  		VpcId: vpc.Vpc.VpcId,
   159  	})
   160  	if err != nil {
   161  		return "", fmt.Errorf("enabling DNS Hostnames VPC attribute: %v", err)
   162  	}
   163  	_, err = a.ec2.ModifyVpcAttribute(&ec2.ModifyVpcAttributeInput{
   164  		EnableDnsSupport: &ec2.AttributeBooleanValue{
   165  			Value: aws.Bool(true),
   166  		},
   167  		VpcId: vpc.Vpc.VpcId,
   168  	})
   169  	if err != nil {
   170  		return "", fmt.Errorf("enabling DNS Support VPC attribute: %v", err)
   171  	}
   172  
   173  	routeTable, err := a.createRouteTable(*vpc.Vpc.VpcId)
   174  	if err != nil {
   175  		return "", fmt.Errorf("creating RouteTable: %v", err)
   176  	}
   177  
   178  	err = a.createSubnets(*vpc.Vpc.VpcId, routeTable)
   179  	if err != nil {
   180  		return "", fmt.Errorf("creating subnets: %v", err)
   181  	}
   182  
   183  	return *vpc.Vpc.VpcId, nil
   184  }
   185  
   186  // createRouteTable creates a RouteTable with a local target for destination
   187  // 172.31.0.0/16 as well as an InternetGateway for destination 0.0.0.0/0
   188  func (a *API) createRouteTable(vpcId string) (string, error) {
   189  	rt, err := a.ec2.CreateRouteTable(&ec2.CreateRouteTableInput{
   190  		VpcId: &vpcId,
   191  	})
   192  	if err != nil {
   193  		return "", err
   194  	}
   195  	if rt.RouteTable == nil || rt.RouteTable.RouteTableId == nil {
   196  		return "", fmt.Errorf("route table was nil after creation")
   197  	}
   198  
   199  	err = a.tagCreatedByMantle([]string{*rt.RouteTable.RouteTableId})
   200  	if err != nil {
   201  		return "", err
   202  	}
   203  
   204  	igw, err := a.createInternetGateway(vpcId)
   205  	if err != nil {
   206  		return "", fmt.Errorf("creating internet gateway: %v", err)
   207  	}
   208  
   209  	_, err = a.ec2.CreateRoute(&ec2.CreateRouteInput{
   210  		DestinationCidrBlock: aws.String("0.0.0.0/0"),
   211  		GatewayId:            aws.String(igw),
   212  		RouteTableId:         rt.RouteTable.RouteTableId,
   213  	})
   214  	if err != nil {
   215  		return "", fmt.Errorf("creating remote route: %v", err)
   216  	}
   217  
   218  	return *rt.RouteTable.RouteTableId, nil
   219  }
   220  
   221  // creates an InternetGateway and attaches it to the given VPC
   222  func (a *API) createInternetGateway(vpcId string) (string, error) {
   223  	igw, err := a.ec2.CreateInternetGateway(&ec2.CreateInternetGatewayInput{})
   224  	if err != nil {
   225  		return "", err
   226  	}
   227  	if igw.InternetGateway == nil || igw.InternetGateway.InternetGatewayId == nil {
   228  		return "", fmt.Errorf("internet gateway was nil")
   229  	}
   230  	err = a.tagCreatedByMantle([]string{*igw.InternetGateway.InternetGatewayId})
   231  	if err != nil {
   232  		return "", err
   233  	}
   234  	_, err = a.ec2.AttachInternetGateway(&ec2.AttachInternetGatewayInput{
   235  		InternetGatewayId: igw.InternetGateway.InternetGatewayId,
   236  		VpcId:             &vpcId,
   237  	})
   238  	if err != nil {
   239  		return "", fmt.Errorf("attaching internet gateway to vpc: %v", err)
   240  	}
   241  	return *igw.InternetGateway.InternetGatewayId, nil
   242  }
   243  
   244  // createSubnets creates a subnet in each availability zone for the region
   245  // that is associated with the given VPC associated with the given RouteTable
   246  func (a *API) createSubnets(vpcId, routeTableId string) error {
   247  	azs, err := a.ec2.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{})
   248  	if err != nil {
   249  		return fmt.Errorf("retrieving availability zones: %v", err)
   250  	}
   251  
   252  	for i, az := range azs.AvailabilityZones {
   253  		// 16 is the maximum amount of zones possible when giving them a /20
   254  		// CIDR range inside of a /16 network.
   255  		if i > 15 {
   256  			return nil
   257  		}
   258  
   259  		if az.ZoneName == nil {
   260  			continue
   261  		}
   262  
   263  		name := *az.ZoneName
   264  		sub, err := a.ec2.CreateSubnet(&ec2.CreateSubnetInput{
   265  			AvailabilityZone: aws.String(name),
   266  			VpcId:            &vpcId,
   267  			// Increment the CIDR block by 16 every time
   268  			CidrBlock: aws.String(fmt.Sprintf("172.31.%d.0/20", i*16)),
   269  		})
   270  		if err != nil {
   271  			// Some availability zones get returned but cannot have subnets
   272  			// created inside of them
   273  			if awsErr, ok := (err).(awserr.Error); ok && awsErr.Code() == "InvalidParameterValue" {
   274  				continue
   275  			}
   276  			return fmt.Errorf("creating subnet: %v", err)
   277  		}
   278  		if sub.Subnet == nil || sub.Subnet.SubnetId == nil {
   279  			return fmt.Errorf("subnet was nil after creation")
   280  		}
   281  		err = a.tagCreatedByMantle([]string{*sub.Subnet.SubnetId})
   282  		if err != nil {
   283  			return err
   284  		}
   285  		_, err = a.ec2.ModifySubnetAttribute(&ec2.ModifySubnetAttributeInput{
   286  			SubnetId: sub.Subnet.SubnetId,
   287  			MapPublicIpOnLaunch: &ec2.AttributeBooleanValue{
   288  				Value: aws.Bool(true),
   289  			},
   290  		})
   291  		if err != nil {
   292  			return err
   293  		}
   294  
   295  		_, err = a.ec2.AssociateRouteTable(&ec2.AssociateRouteTableInput{
   296  			RouteTableId: &routeTableId,
   297  			SubnetId:     sub.Subnet.SubnetId,
   298  		})
   299  		if err != nil {
   300  			return fmt.Errorf("associating subnet with route table: %v", err)
   301  		}
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  // getSubnetID gets a subnet for the given VPC.
   308  func (a *API) getSubnetID(vpc string) (string, error) {
   309  	subIds, err := a.ec2.DescribeSubnets(&ec2.DescribeSubnetsInput{
   310  		Filters: []*ec2.Filter{
   311  			{
   312  				Name:   aws.String("vpc-id"),
   313  				Values: []*string{&vpc},
   314  			},
   315  		},
   316  	})
   317  	if err != nil {
   318  		return "", fmt.Errorf("unable to get subnets for vpc %v: %v", vpc, err)
   319  	}
   320  	for _, id := range subIds.Subnets {
   321  		if id.SubnetId != nil {
   322  			return *id.SubnetId, nil
   323  		}
   324  	}
   325  	return "", fmt.Errorf("no subnets found for vpc %v", vpc)
   326  }
   327  
   328  // getVPCID gets a VPC for the given security group
   329  func (a *API) getVPCID(sgId string) (string, error) {
   330  	sgs, err := a.ec2.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
   331  		GroupIds: []*string{&sgId},
   332  	})
   333  	if err != nil {
   334  		return "", fmt.Errorf("listing vpc's: %v", err)
   335  	}
   336  	for _, sg := range sgs.SecurityGroups {
   337  		if sg.VpcId != nil {
   338  			return *sg.VpcId, nil
   339  		}
   340  	}
   341  	return "", fmt.Errorf("no vpc found for security group %v", sgId)
   342  }