github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/tools/scheduler/main.py (about)

     1  import collections
     2  import json
     3  import logging
     4  import operator
     5  import re
     6  
     7  import flask
     8  from oauth2client.client import GoogleCredentials
     9  from googleapiclient import discovery
    10  
    11  from google.appengine.api import urlfetch
    12  from google.appengine.ext import ndb
    13  
# Flask application object; the route decorators in this file register on it.
app = flask.Flask('scheduler')
# NOTE(review): debug mode is switched on unconditionally here -- confirm
# that this is intended for the deployed service.
app.debug = True

# We use an exponential moving average to record
# test run times.  A higher alpha discounts historic
# observations faster.
alpha = 0.3
    21  
    22  
    23  class Test(ndb.Model):
    24      total_run_time = ndb.FloatProperty(default=0.)  # Not total, but a EWMA
    25      total_runs = ndb.IntegerProperty(default=0)
    26  
    27      def parallelism(self):
    28          name = self.key.string_id()
    29          m = re.search('(\d+)_test.sh$', name)
    30          if m is None:
    31              return 1
    32          else:
    33              return int(m.group(1))
    34  
    35      def cost(self):
    36          p = self.parallelism()
    37          logging.info("Test %s has parallelism %d and avg run time %s",
    38                       self.key.string_id(), p, self.total_run_time)
    39          return self.parallelism() * self.total_run_time
    40  
    41  
    42  class Schedule(ndb.Model):
    43      shards = ndb.JsonProperty()
    44  
    45  
    46  @app.route('/record/<path:test_name>/<runtime>', methods=['POST'])
    47  @ndb.transactional
    48  def record(test_name, runtime):
    49      test = Test.get_by_id(test_name)
    50      if test is None:
    51          test = Test(id=test_name)
    52      test.total_run_time = (test.total_run_time *
    53                             (1 - alpha)) + (float(runtime) * alpha)
    54      test.total_runs += 1
    55      test.put()
    56      return ('', 204)
    57  
    58  
    59  @app.route(
    60      '/schedule/<test_run>/<int:shard_count>/<int:shard>', methods=['POST'])
    61  def schedule(test_run, shard_count, shard):
    62      # read tests from body
    63      test_names = flask.request.get_json(force=True)['tests']
    64  
    65      # first see if we have a scedule already
    66      schedule_id = "%s-%d" % (test_run, shard_count)
    67      schedule = Schedule.get_by_id(schedule_id)
    68      if schedule is not None:
    69          return flask.json.jsonify(tests=schedule.shards[str(shard)])
    70  
    71      # if not, do simple greedy algorithm
    72      test_times = ndb.get_multi(
    73          ndb.Key(Test, test_name) for test_name in test_names)
    74  
    75      def avg(test):
    76          if test is not None:
    77              return test.cost()
    78          return 1
    79  
    80      test_times = [(test_name, avg(test))
    81                    for test_name, test in zip(test_names, test_times)]
    82      test_times_dict = dict(test_times)
    83      test_times.sort(key=operator.itemgetter(1))
    84  
    85      shards = {i: [] for i in xrange(shard_count)}
    86      while test_times:
    87          test_name, time = test_times.pop()
    88  
    89          # find shortest shard and put it in that
    90          s, _ = min(
    91              ((i, sum(test_times_dict[t] for t in shards[i]))
    92               for i in xrange(shard_count)),
    93              key=operator.itemgetter(1))
    94  
    95          shards[s].append(test_name)
    96  
    97      # atomically insert or retrieve existing schedule
    98      schedule = Schedule.get_or_insert(schedule_id, shards=shards)
    99      return flask.json.jsonify(tests=schedule.shards[str(shard)])
   100  
   101  
   102  FIREWALL_REGEXES = [
   103      re.compile(
   104          r'^(?P<network>\w+)-allow-(?P<type>\w+)-(?P<build>\d+)-(?P<shard>\d+)$'
   105      ),
   106      re.compile(r'^(?P<network>\w+)-(?P<build>\d+)-(?P<shard>\d+)-allow-'
   107                 r'(?P<type>[\w\-]+)$'),
   108  ]
   109  NAME_REGEXES = [
   110      re.compile(pat)
   111      for pat in (
   112          r'^host(?P<index>\d+)-(?P<build>\d+)-(?P<shard>\d+)$',
   113          r'^host(?P<index>\d+)-(?P<project>[a-zA-Z0-9-]+)-(?P<build>\d+)'
   114          r'-(?P<shard>\d+)$',
   115          r'^test-(?P<build>\d+)-(?P<shard>\d+)-(?P<index>\d+)$', )
   116  ]
   117  
   118  
   119  def _matches_any_regex(name, regexes):
   120      for regex in regexes:
   121          matches = regex.match(name)
   122          if matches:
   123              return matches
   124  
   125  
# See also: https://circleci.com/account/api
# NOTE(review): this API token is committed in source. Consider loading it
# from deployment configuration and rotating the token.
CIRCLE_CI_API_TOKEN = 'cffb83afd920cfa109cbd3e9eecb7511a2d18bb9'

# N.B.: When adding a project below, please ensure:
# - its CircleCI project is either public, or is followed by the user attached
#   to the above API token
# - user positive-cocoa-90213@appspot.gserviceaccount.com has "Compute Admin"
#   access to its GCP project (or any other role including
#   compute.instances.list/delete and compute.firewalls.list/delete)
#
# Each entry is a tuple of:
#   (CircleCI repo, GCP project, GCE zone, whether to also GC firewall
#    rules, CircleCI API token or None)
PROJECTS = [
    ('weaveworks/weave', 'weave-net-tests', 'us-central1-a', True, None),
    ('weaveworks/weave', 'positive-cocoa-90213', 'us-central1-a', True, None),
    ('weaveworks/scope', 'scope-integration-tests', 'us-central1-a', False,
     None),
    ('weaveworks/wks', 'wks-tests', 'us-central1-a', True,
     CIRCLE_CI_API_TOKEN),
]
   143  
   144  
   145  @app.route('/tasks/gc')
   146  def gc():
   147      # Get list of running VMs, pick build id out of VM name
   148      credentials = GoogleCredentials.get_application_default()
   149      compute = discovery.build(
   150          'compute', 'v1', credentials=credentials, cache_discovery=False)
   151  
   152      for repo, project, zone, gc_fw, circleci_api_token in PROJECTS:
   153          gc_project(compute, repo, project, zone, gc_fw, circleci_api_token)
   154  
   155      return "Done"
   156  
   157  
   158  def gc_project(compute, repo, project, zone, gc_fw, circleci_api_token):
   159      logging.info("GCing %s, %s, %s", repo, project, zone)
   160      # Get list of builds, filter down to running builds:
   161      running = _get_running_builds(repo, circleci_api_token)
   162      # Stop VMs for builds that aren't running:
   163      _gc_compute_engine_instances(compute, project, zone, running)
   164      # Remove firewall rules for builds that aren't running:
   165      if gc_fw:
   166          _gc_firewall_rules(compute, project, running)
   167  
   168  
   169  def _get_running_builds(repo, circleci_api_token):
   170      if circleci_api_token:
   171          url = 'https://circleci.com/api/v1/project/%s?circle-token=%s' % (
   172              repo, circleci_api_token)
   173      else:
   174          url = 'https://circleci.com/api/v1/project/%s' % repo
   175      result = urlfetch.fetch(url, headers={'Accept': 'application/json'})
   176      if result.status_code != 200:
   177          raise RuntimeError(
   178              'Failed to get builds for "%s". URL: %s, Status: %s. Response: %s'
   179              % (repo, url, result.status_code, result.content))
   180      builds = json.loads(result.content)
   181      running = {
   182          build['build_num']
   183          for build in builds if not build.get('stop_time')
   184      }
   185      logging.info("Runnings builds: %r", running)
   186      return running
   187  
   188  
   189  def _get_hosts_by_build(instances):
   190      host_by_build = collections.defaultdict(list)
   191      for instance in instances['items']:
   192          matches = _matches_any_regex(instance['name'], NAME_REGEXES)
   193          if not matches:
   194              continue
   195          host_by_build[int(matches.group('build'))].append(instance['name'])
   196      logging.info("Running VMs by build: %r", host_by_build)
   197      return host_by_build
   198  
   199  
   200  def _gc_compute_engine_instances(compute, project, zone, running):
   201      instances = compute.instances().list(project=project, zone=zone).execute()
   202      if 'items' not in instances:
   203          return
   204      host_by_build = _get_hosts_by_build(instances)
   205      stopped = []
   206      for build, names in host_by_build.iteritems():
   207          if build in running:
   208              continue
   209          for name in names:
   210              stopped.append(name)
   211              logging.info("Stopping VM %s", name)
   212              compute.instances().delete(
   213                  project=project, zone=zone, instance=name).execute()
   214      return stopped
   215  
   216  
   217  def _gc_firewall_rules(compute, project, running):
   218      firewalls = compute.firewalls().list(project=project).execute()
   219      if 'items' not in firewalls:
   220          return
   221      for firewall in firewalls['items']:
   222          matches = _matches_any_regex(firewall['name'], FIREWALL_REGEXES)
   223          if not matches:
   224              continue
   225          if int(matches.group('build')) in running:
   226              continue
   227          logging.info("Deleting firewall rule %s", firewall['name'])
   228          compute.firewalls().delete(
   229              project=project, firewall=firewall['name']).execute()