github.com/rochacon/deis@v1.0.2-0.20150903015341-6839b592a1ff/contrib/linode/apply-firewall.py (about)

     1  #!/usr/bin/env python
     2  """
     3  Apply a "Security Group" to the members of an etcd cluster.
     4  
     5  Usage: apply-firewall.py
     6  """
     7  import os
     8  import re
     9  import string
    10  import argparse
    11  from threading import Thread
    12  import uuid
    13  
    14  import colorama
    15  from colorama import Fore, Style
    16  import paramiko
    17  import requests
    18  import sys
    19  import yaml
    20  
    21  
    22  def get_nodes_from_args(args):
    23      if args.discovery_url is not None:
    24          return get_nodes_from_discovery_url(args.discovery_url)
    25  
    26      return get_nodes_from_discovery_url(get_discovery_url_from_user_data())
    27  
    28  
    29  def get_nodes_from_discovery_url(discovery_url):
    30      try:
    31          nodes = []
    32          json = requests.get(discovery_url).json()
    33          discovery_nodes = json['node']['nodes']
    34          for node in discovery_nodes:
    35              value = node['value']
    36              ip = re.search('([0-9]{1,3}\.){3}[0-9]{1,3}', value).group(0)
    37              nodes.append(ip)
    38          return nodes
    39      except:
    40          raise IOError('Could not load nodes from discovery url ' + discovery_url)
    41  
    42  
    43  def get_discovery_url_from_user_data():
    44      name = 'linode-user-data.yaml'
    45      log_info('Loading discovery url from ' + name)
    46      try:
    47          current_dir = os.path.dirname(__file__)
    48          user_data_file = file(os.path.abspath(os.path.join(current_dir, name)), 'r')
    49          return re.search('--discovery (http\S+)', user_data_file.read()).group(1)
    50      except:
    51          raise IOError('Could not load discovery url from ' + name)
    52  
    53  
    54  def validate_ip_address(ip):
    55      return True if re.match('([0-9]{1,3}\.){3}[0-9]{1,3}', ip) else False
    56  
    57  
    58  def get_firewall_contents(node_ips, private=False):
    59      rules_template_text = """*filter
    60  :INPUT DROP [0:0]
    61  :FORWARD DROP [0:0]
    62  :OUTPUT ACCEPT [0:0]
    63  :DOCKER - [0:0]
    64  :Firewall-INPUT - [0:0]
    65  -A INPUT -j Firewall-INPUT
    66  -A FORWARD -j Firewall-INPUT
    67  -A Firewall-INPUT -i lo -j ACCEPT
    68  -A Firewall-INPUT -p icmp --icmp-type echo-reply -j ACCEPT
    69  -A Firewall-INPUT -p icmp --icmp-type destination-unreachable -j ACCEPT
    70  -A Firewall-INPUT -p icmp --icmp-type time-exceeded -j ACCEPT
    71  # Ping
    72  -A Firewall-INPUT -p icmp --icmp-type echo-request -j ACCEPT
    73  # Accept any established connections
    74  -A Firewall-INPUT -m conntrack --ctstate  ESTABLISHED,RELATED -j ACCEPT
    75  # Enable the traffic between the nodes of the cluster
    76  -A Firewall-INPUT -s $node_ips -j ACCEPT
    77  # Allow connections from docker container
    78  -A Firewall-INPUT -i docker0 -j ACCEPT
    79  # Accept ssh, http, https and git
    80  -A Firewall-INPUT -m conntrack --ctstate NEW -m multiport$multiport_private -p tcp --dports 22,2222,80,443 -j ACCEPT
    81  # Log and drop everything else
    82  -A Firewall-INPUT -j REJECT
    83  COMMIT
    84  """
    85  
    86      multiport_private = ' -s 192.168.0.0/16' if private else ''
    87  
    88      rules_template = string.Template(rules_template_text)
    89      return rules_template.substitute(node_ips=string.join(node_ips, ','), multiport_private=multiport_private)
    90  
    91  
    92  def apply_rules_to_all(host_ips, rules, private_key):
    93      pkey = detect_and_create_private_key(private_key)
    94  
    95      threads = []
    96      for ip in host_ips:
    97          t = Thread(target=apply_rules, args=(ip, rules, pkey))
    98          t.setDaemon(False)
    99          t.start()
   100          threads.append(t)
   101      for thread in threads:
   102          thread.join()
   103  
   104  
   105  def detect_and_create_private_key(private_key):
   106      private_key_text = private_key.read()
   107      private_key.seek(0)
   108      if '-----BEGIN RSA PRIVATE KEY-----' in private_key_text:
   109          return paramiko.RSAKey.from_private_key(private_key)
   110      elif '-----BEGIN DSA PRIVATE KEY-----' in private_key_text:
   111          return paramiko.DSSKey.from_private_key(private_key)
   112      else:
   113          raise ValueError('Invalid private key file ' + private_key.name)
   114  
   115  
   116  def apply_rules(host_ip, rules, private_key):
   117      # connect to the server via ssh
   118      ssh = paramiko.SSHClient()
   119      ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
   120      ssh.connect(host_ip, username='core', allow_agent=False, look_for_keys=False, pkey=private_key)
   121  
   122      # copy the rules to the temp directory
   123      temp_file = '/tmp/' + str(uuid.uuid4())
   124  
   125      ssh.open_sftp()
   126      sftp = ssh.open_sftp()
   127      sftp.open(temp_file, 'w').write(rules)
   128  
   129      # move the rules in to place and enable and run the iptables-restore.service
   130      commands = [
   131          'sudo mv ' + temp_file + ' /var/lib/iptables/rules-save',
   132          'sudo chown root:root /var/lib/iptables/rules-save',
   133          'sudo systemctl enable iptables-restore.service',
   134          'sudo systemctl start iptables-restore.service'
   135      ]
   136  
   137      for command in commands:
   138          stdin, stdout, stderr = ssh.exec_command(command)
   139          stdout.channel.recv_exit_status()
   140  
   141      ssh.close()
   142  
   143      log_success('Applied rule to ' + host_ip)
   144  
   145  
   146  def main():
   147      colorama.init()
   148  
   149      parser = argparse.ArgumentParser(description='Apply a "Security Group" to a Deis cluster')
   150      parser.add_argument('--private-key', required=True, type=file, dest='private_key', help='Cluster SSH Private Key')
   151      parser.add_argument('--private', action='store_true', dest='private', help='Only allow access to the cluster from the private network')
   152      parser.add_argument('--discovery-url', dest='discovery_url', help='Etcd discovery url')
   153      parser.add_argument('--hosts', nargs='+', dest='hosts', help='The IP addresses of the hosts to apply rules to')
   154      args = parser.parse_args()
   155  
   156      nodes = get_nodes_from_args(args)
   157      hosts = args.hosts if args.hosts is not None else nodes
   158  
   159      node_ips = []
   160      for ip in nodes:
   161          if validate_ip_address(ip):
   162              node_ips.append(ip)
   163          else:
   164              log_warning('Invalid IP will not be added to security group: ' + ip)
   165  
   166      if not len(node_ips) > 0:
   167          raise ValueError('No valid IP addresses in security group.')
   168  
   169      host_ips = []
   170      for ip in hosts:
   171          if validate_ip_address(ip):
   172              host_ips.append(ip)
   173          else:
   174              log_warning('Host has invalid IP address: ' + ip)
   175  
   176      if not len(host_ips) > 0:
   177          raise ValueError('No valid host addresses.')
   178  
   179      log_info('Generating iptables rules...')
   180      rules = get_firewall_contents(node_ips, args.private)
   181      log_success('Generated rules:')
   182      log_debug(rules)
   183  
   184      log_info('Applying rules...')
   185      apply_rules_to_all(host_ips, rules, args.private_key)
   186      log_success('Done!')
   187  
   188  
   189  def log_debug(message):
   190      print(Style.DIM + Fore.MAGENTA + message + Fore.RESET + Style.RESET_ALL)
   191  
   192  
   193  def log_info(message):
   194      print(Fore.CYAN + message + Fore.RESET)
   195  
   196  
   197  def log_warning(message):
   198      print(Fore.YELLOW + message + Fore.RESET)
   199  
   200  
   201  def log_success(message):
   202      print(Style.BRIGHT + Fore.GREEN + message + Fore.RESET + Style.RESET_ALL)
   203  
   204  
   205  def log_error(message):
   206      print(Style.BRIGHT + Fore.RED + message + Fore.RESET + Style.RESET_ALL)
   207  
   208  if __name__ == "__main__":
   209      try:
   210          main()
   211      except Exception as e:
   212          log_error(e.message)
   213          sys.exit(1)