github.com/deis/deis@v1.13.5-0.20170519182049-1d9e59fbdbfc/contrib/linode/apply-firewall.py (about)

     1  #!/usr/bin/env python
     2  """
     3  Apply a "Security Group" to the members of an etcd cluster.
     4  
     5  Usage: apply-firewall.py
     6  """
     7  import re
     8  import string
     9  import argparse
    10  import threading
    11  from threading import Thread
    12  import uuid
    13  
    14  import paramiko
    15  import requests
    16  import sys
    17  import yaml
    18  
    19  from linodeapi import LinodeApiCommand
    20  import linodeutils 
    21  
    22  class FirewallCommand(LinodeApiCommand):
    23  
    24  
    25      def get_nodes_from_args(self):
    26          if not self.discovery_url:
    27              self.discovery_url = self.get_discovery_url_from_user_data()
    28          return self.get_nodes_from_discovery_url(self.discovery_url)
    29      
    30      def get_nodes_from_discovery_url(self, discovery_url):
    31          try:
    32              nodes = []
    33              json = requests.get(discovery_url).json()
    34              discovery_nodes = json['node']['nodes']
    35              for node in discovery_nodes:
    36                  value = node['value']
    37                  ip = re.search('([0-9]{1,3}\.){3}[0-9]{1,3}', value).group(0)
    38                  nodes.append(ip)
    39              return nodes
    40          except:
    41              raise IOError('Could not load nodes from discovery url ' + discovery_url)
    42  
    43  
    44      def get_discovery_url_from_user_data(self):
    45          name = 'linode-user-data.yaml'
    46          linodeutils.log_info('Loading discovery url from ' + name)
    47          try:
    48              user_data_file = linodeutils.get_file(name)
    49              user_data_yaml = yaml.safe_load(user_data_file)
    50              return user_data_yaml['coreos']['etcd2']['discovery']
    51          except:
    52              raise IOError('Could not load discovery url from ' + name)
    53  
    54  
    55      def validate_ip_address(self, ip):
    56          return True if re.match('([0-9]{1,3}\.){3}[0-9]{1,3}', ip) else False
    57  
    58  
    59      def get_firewall_contents(self, node_ips):
    60          rules_template_text = """*filter
    61  :INPUT DROP [0:0]
    62  :FORWARD DROP [0:0]
    63  :OUTPUT ACCEPT [0:0]
    64  :DOCKER - [0:0]
    65  :Firewall-INPUT - [0:0]
    66  -A INPUT -j Firewall-INPUT
    67  -A FORWARD -j Firewall-INPUT
    68  -A Firewall-INPUT -i lo -j ACCEPT
    69  -A Firewall-INPUT -p icmp --icmp-type echo-reply -j ACCEPT
    70  -A Firewall-INPUT -p icmp --icmp-type destination-unreachable -j ACCEPT
    71  -A Firewall-INPUT -p icmp --icmp-type time-exceeded -j ACCEPT
    72  # Ping
    73  -A Firewall-INPUT -p icmp --icmp-type echo-request -j ACCEPT
    74  # Accept any established connections
    75  -A Firewall-INPUT -m conntrack --ctstate  ESTABLISHED,RELATED -j ACCEPT
    76  # Enable the traffic between the nodes of the cluster
    77  -A Firewall-INPUT -s $node_ips -j ACCEPT
    78  # Allow connections from docker container
    79  -A Firewall-INPUT -i docker0 -j ACCEPT
    80  # Accept ssh, http, https and git
    81  -A Firewall-INPUT -m conntrack --ctstate NEW -m multiport$multiport_private -p tcp --dports 22,2222,80,443$add_new_nodes -j ACCEPT
    82  # Log and drop everything else
    83  -A Firewall-INPUT -j REJECT
    84  COMMIT
    85  """
    86  
    87          multiport_private = ' -s 192.168.0.0/16' if self.private else ''
    88          add_new_nodes = ',2379,2380' if self.adding_new_nodes else ''
    89          
    90          rules_template = string.Template(rules_template_text)
    91          return rules_template.substitute(node_ips=string.join(node_ips, ','), multiport_private=multiport_private, add_new_nodes=add_new_nodes)
    92  
    93  
    94      def apply_rules_to_all(self, host_ips, rules):
    95          pkey = self.detect_and_create_private_key()
    96          
    97          threads = []
    98          for ip in host_ips:
    99              t = Thread(target=self.apply_rules, args=(ip, rules, pkey))
   100              t.setDaemon(False)
   101              t.start()
   102              threads.append(t)
   103          for thread in threads:
   104              thread.join()
   105  
   106  
   107      def detect_and_create_private_key(self):
   108          private_key_text = self.private_key.read()
   109          self.private_key.seek(0)
   110          if '-----BEGIN RSA PRIVATE KEY-----' in private_key_text:
   111              return paramiko.RSAKey.from_private_key(self.private_key)
   112          elif '-----BEGIN DSA PRIVATE KEY-----' in private_key_text:
   113              return paramiko.DSSKey.from_private_key(self.private_key)
   114          else:
   115              raise ValueError('Invalid private key file ' + self.private_key.name)
   116  
   117  
   118      def apply_rules(self, host_ip, rules, private_key):
   119          # connect to the server via ssh
   120          ssh = paramiko.SSHClient()
   121          ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
   122          ssh.connect(host_ip, username='core', allow_agent=False, look_for_keys=False, pkey=private_key)
   123  
   124          # copy the rules to the temp directory
   125          temp_file = '/tmp/' + str(uuid.uuid4())
   126  
   127          ssh.open_sftp()
   128          sftp = ssh.open_sftp()
   129          sftp.open(temp_file, 'w').write(rules)
   130  
   131          # move the rules in to place and enable and run the iptables-restore.service
   132          commands = [
   133              'sudo mv ' + temp_file + ' /var/lib/iptables/rules-save',
   134              'sudo chown root:root /var/lib/iptables/rules-save',
   135              'sudo systemctl enable iptables-restore.service',
   136              'sudo systemctl start iptables-restore.service'
   137          ]
   138  
   139          for command in commands:
   140              stdin, stdout, stderr = ssh.exec_command(command)
   141              stdout.channel.recv_exit_status()
   142  
   143          ssh.close()
   144  
   145          linodeutils.log_success('Applied rule to ' + host_ip)
   146  
   147      def acquire_linode_ips(self):
   148          linodeutils.log_info('Getting info for Linodes from display group: ' + self.node_display_group)
   149          deis_grp = self.request('linode.list')
   150          deis_linodeids = [l.get('LINODEID','') for l in deis_grp if l.get('LPM_DISPLAYGROUP', '') == self.node_display_group]
   151          deis_grp_ips = self.request('linode.ip.list')
   152          self.deis_privateips = [ip.get('IPADDRESS','') for ip in deis_grp_ips if (ip.get('LINODEID','') in deis_linodeids) and (ip.get('ISPUBLIC', 1) == 0)]
   153          self.deis_publicips = [ip.get('IPADDRESS','') for ip in deis_grp_ips if (ip.get('LINODEID','') in deis_linodeids) and (ip.get('ISPUBLIC', 0) == 1)]
   154  
   155      def run(self):
   156          #NOTE: defaults to using display group, then manual input (via nodes/hosts), then discovery_url
   157          if self.node_display_group:
   158              self.acquire_linode_ips()
   159              nodes = self.deis_privateips
   160              hosts = self.deis_publicips
   161          else:
   162              nodes = self.nodes if self.nodes is not None else self.get_nodes_from_args()
   163              hosts = self.hosts if self.hosts is not None else nodes
   164  
   165          node_ips = []
   166          for ip in nodes:
   167              if self.validate_ip_address(ip):
   168                  node_ips.append(ip)
   169              else:
   170                  linodeutils.log_warning('Invalid IP will not be added to security group: ' + ip)
   171  
   172          if not len(node_ips) > 0:
   173              raise ValueError('No valid IP addresses in security group.')
   174  
   175          host_ips = []
   176          for ip in hosts:
   177              if self.validate_ip_address(ip):
   178                  host_ips.append(ip)
   179              else:
   180                  linodeutils.log_warning('Host has invalid IP address: ' + ip)
   181  
   182          if not len(host_ips) > 0:
   183              raise ValueError('No valid host addresses.')
   184  
   185          linodeutils.log_info('Generating iptables rules...')
   186          rules = self.get_firewall_contents(node_ips)
   187          linodeutils.log_success('Generated rules:')
   188          linodeutils.log_debug(rules)
   189          
   190          linodeutils.log_info('Applying rules...')
   191          self.apply_rules_to_all(host_ips, rules)
   192          linodeutils.log_success('Done!')
   193  
   194  
   195  def main():
   196      linodeutils.init()
   197  
   198      parser = argparse.ArgumentParser(description='Apply a "Security Group" to a Deis cluster')
   199      parser.add_argument('--api-key', dest='linode_api_key', help='Linode API Key')
   200      parser.add_argument('--private-key', required=True, type=file, dest='private_key', help='Cluster SSH Private Key')
   201      parser.add_argument('--private', action='store_true', dest='private', help='Only allow access to the cluster from the private network')
   202      parser.add_argument('--adding-new-nodes', action='store_true', dest='adding_new_nodes', help='When adding new nodes to existing cluster, allows access to etcd')
   203      parser.add_argument('--discovery-url', dest='discovery_url', help='Etcd discovery url')
   204      parser.add_argument('--display-group', required=False, dest='node_display_group', help='Display group (used for Linode IP discovery).')
   205      parser.add_argument('--hosts', nargs='+', dest='hosts', help='The public IP addresses of the hosts to apply rules to (for ssh)')
   206      parser.add_argument('--nodes', nargs='+', dest='nodes', help='The private IP addresses of the hosts (for iptable setup)')
   207      parser.set_defaults(cmd=FirewallCommand)
   208      
   209      args = parser.parse_args()
   210      cmd = args.cmd(args)
   211      args.cmd(args).run()
   212      
   213  if __name__ == "__main__":
   214      try:
   215          main()
   216      except Exception as e:
   217          linodeutils.log_error(e.message)
   218          sys.exit(1)