# github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/tools/scheduler/main.py
import collections
import json
import logging
import operator
import re

import flask
from oauth2client.client import GoogleCredentials
from googleapiclient import discovery

from google.appengine.api import urlfetch
from google.appengine.ext import ndb

app = flask.Flask('scheduler')
# NOTE(review): debug mode is switched on unconditionally -- confirm this is
# intended for the deployed service.
app.debug = True

# We use exponential moving average to record
# test run times. Higher alpha discounts historic
# observations faster.
alpha = 0.3

# Matches a trailing "<digits>_test.sh" in a test name; the digits are the
# test's parallelism.
_PARALLELISM_RE = re.compile(r'(\d+)_test\.sh$')


class Test(ndb.Model):
    """Run-time statistics for one test, keyed by the test's name."""

    total_run_time = ndb.FloatProperty(default=0.)  # Not total, but a EWMA
    total_runs = ndb.IntegerProperty(default=0)

    def parallelism(self):
        """Return the number encoded in a key ending in "<N>_test.sh".

        Returns 1 when the key does not end in that pattern.
        """
        m = _PARALLELISM_RE.search(self.key.string_id())
        if m is None:
            return 1
        return int(m.group(1))

    def cost(self):
        """Return the parallelism multiplied by the averaged run time."""
        p = self.parallelism()
        logging.info("Test %s has parallelism %d and avg run time %s",
                     self.key.string_id(), p, self.total_run_time)
        return p * self.total_run_time


class Schedule(ndb.Model):
    """A computed assignment of tests to shards, stored as JSON."""

    # Mapping of shard index to the list of test names in that shard.
    shards = ndb.JsonProperty()


@app.route('/record/<path:test_name>/<runtime>', methods=['POST'])
@ndb.transactional
def record(test_name, runtime):
    """Fold one observed run time into the stored average for a test.

    Creates the Test entity on first sight, updates the exponentially
    weighted average with weight ``alpha`` and bumps the run counter.
    ``runtime`` arrives as a string from the URL and is parsed as a float.
    Responds with an empty 204.
    """
    test = Test.get_by_id(test_name)
    if test is None:
        test = Test(id=test_name)
    # NOTE(review): for a brand-new test the first sample is blended with the
    # default of 0.0, so early averages are biased low -- confirm intended.
    test.total_run_time = (test.total_run_time *
                           (1 - alpha)) + (float(runtime) * alpha)
    test.total_runs += 1
    test.put()
    return ('', 204)


@app.route(
    '/schedule/<test_run>/<int:shard_count>/<int:shard>', methods=['POST'])
def schedule(test_run, shard_count, shard):
    """Return the tests assigned to ``shard`` for the given test run.

    The request body is JSON with a ``tests`` list of test names. If a
    schedule for (test_run, shard_count) is already stored it is reused;
    otherwise tests are distributed greedily -- most expensive first, each
    into the currently least-loaded shard -- and the result is stored.
    Responds with 400 when ``shard`` is not a valid index for
    ``shard_count``.
    """
    # Reject indexes that could never be in the schedule instead of failing
    # later with a KeyError (or ValueError when there are no shards at all).
    if not 0 <= shard < shard_count:
        flask.abort(400)

    # read tests from body
    test_names = flask.request.get_json(force=True)['tests']

    # first see if we have a schedule already
    schedule_id = "%s-%d" % (test_run, shard_count)
    schedule = Schedule.get_by_id(schedule_id)
    if schedule is not None:
        return flask.json.jsonify(tests=schedule.shards[str(shard)])

    # if not, do simple greedy algorithm
    tests = ndb.get_multi(
        ndb.Key(Test, test_name) for test_name in test_names)

    def avg(test):
        # Tests with no recorded history get a nominal cost of 1.
        if test is not None:
            return test.cost()
        return 1

    test_times = [(test_name, avg(test))
                  for test_name, test in zip(test_names, tests)]
    test_times.sort(key=operator.itemgetter(1))

    # Keys are strings because JSON object keys are always strings, and the
    # lookups below index with str(shard).
    shards = {str(i): [] for i in range(shard_count)}
    # Running total per shard, so the least-loaded shard is found without
    # re-summing every shard for every test.
    loads = [0] * shard_count
    while test_times:
        # Sorted ascending, so pop() yields the most expensive remaining test.
        test_name, cost = test_times.pop()

        # find shortest shard and put it in that (lowest index wins ties)
        s = min(range(shard_count), key=loads.__getitem__)

        shards[str(s)].append(test_name)
        loads[s] += cost

    # atomically insert or retrieve existing schedule
    schedule = Schedule.get_or_insert(schedule_id, shards=shards)
    return flask.json.jsonify(tests=schedule.shards[str(shard)])


# Firewall rule names that carry a build number in the "build" group.
FIREWALL_REGEXES = [
    re.compile(
        r'^(?P<network>\w+)-allow-(?P<type>\w+)-(?P<build>\d+)-(?P<shard>\d+)$'
    ),
    re.compile(r'^(?P<network>\w+)-(?P<build>\d+)-(?P<shard>\d+)-allow-'
               r'(?P<type>[\w\-]+)$'),
]
# VM instance names that carry a build number in the "build" group.
NAME_REGEXES = [
    re.compile(pat)
    for pat in (
        r'^host(?P<index>\d+)-(?P<build>\d+)-(?P<shard>\d+)$',
        r'^host(?P<index>\d+)-(?P<project>[a-zA-Z0-9-]+)-(?P<build>\d+)'
        r'-(?P<shard>\d+)$',
        r'^test-(?P<build>\d+)-(?P<shard>\d+)-(?P<index>\d+)$', )
]


def _matches_any_regex(name, regexes):
    """Return the match object of the first regex matching ``name``.

    Returns None when no regex in ``regexes`` matches.
    """
    for regex in regexes:
        matches = regex.match(name)
        if matches:
            return matches
    return None


# See also: https://circleci.com/account/api
# NOTE(review): this credential is committed in source; consider loading it
# from deployment configuration and rotating it.
CIRCLE_CI_API_TOKEN = 'cffb83afd920cfa109cbd3e9eecb7511a2d18bb9'

# N.B.: When adding a project below, please ensure:
# - its CircleCI project is either public, or is followed by the user attached
#   to the above API token
# - user positive-cocoa-90213@appspot.gserviceaccount.com has "Compute Admin"
#   access to its GCP project (or any other role including
#   compute.instances.list/delete and compute.firewalls.list/delete)
#
# Each entry is (CircleCI repo, GCP project, GCE zone, whether to GC firewall
# rules, CircleCI API token or None).
PROJECTS = [
    ('weaveworks/weave', 'weave-net-tests', 'us-central1-a', True, None),
    ('weaveworks/weave', 'positive-cocoa-90213', 'us-central1-a', True, None),
    ('weaveworks/scope', 'scope-integration-tests', 'us-central1-a', False,
     None),
    ('weaveworks/wks', 'wks-tests', 'us-central1-a', True,
     CIRCLE_CI_API_TOKEN),
]


@app.route('/tasks/gc')
def gc():
    """Garbage-collect cloud resources for every entry in PROJECTS.

    Builds one Compute Engine client from the application default
    credentials and runs gc_project() for each configured project.
    """
    # Get list of running VMs, pick build id out of VM name
    credentials = GoogleCredentials.get_application_default()
    compute = discovery.build(
        'compute', 'v1', credentials=credentials, cache_discovery=False)

    for repo, project, zone, gc_fw, circleci_api_token in PROJECTS:
        gc_project(compute, repo, project, zone, gc_fw, circleci_api_token)

    return "Done"


def gc_project(compute, repo, project, zone, gc_fw, circleci_api_token):
    """Delete VMs (and optionally firewall rules) of builds not running.

    Args:
        compute: Compute Engine API client.
        repo: CircleCI project path, e.g. "org/name".
        project: GCP project to clean.
        zone: GCE zone whose instances are inspected.
        gc_fw: when true, firewall rules are cleaned as well.
        circleci_api_token: token for the CircleCI API, or None.
    """
    logging.info("GCing %s, %s, %s", repo, project, zone)
    # Get list of builds, filter down to running builds:
    running = _get_running_builds(repo, circleci_api_token)
    # Stop VMs for builds that aren't running:
    _gc_compute_engine_instances(compute, project, zone, running)
    # Remove firewall rules for builds that aren't running:
    if gc_fw:
        _gc_firewall_rules(compute, project, running)


def _get_running_builds(repo, circleci_api_token):
    """Return the set of build numbers CircleCI lists without a stop time.

    Raises:
        RuntimeError: when the CircleCI API does not answer with status 200.
    """
    public_url = 'https://circleci.com/api/v1/project/%s' % repo
    if circleci_api_token:
        url = '%s?circle-token=%s' % (public_url, circleci_api_token)
    else:
        url = public_url
    result = urlfetch.fetch(url, headers={'Accept': 'application/json'})
    if result.status_code != 200:
        # Report the token-free URL so the secret never reaches logs or
        # error pages.
        raise RuntimeError(
            'Failed to get builds for "%s". URL: %s, Status: %s. Response: %s'
            % (repo, public_url, result.status_code, result.content))
    builds = json.loads(result.content)
    running = {
        build['build_num']
        for build in builds if not build.get('stop_time')
    }
    logging.info("Runnings builds: %r", running)
    return running


def _get_hosts_by_build(instances):
    """Group instance names by the build number parsed from each name.

    ``instances`` is an instances-list response; names that match none of
    NAME_REGEXES are ignored. Returns a mapping of build number (int) to a
    list of instance names.
    """
    host_by_build = collections.defaultdict(list)
    for instance in instances.get('items', []):
        matches = _matches_any_regex(instance['name'], NAME_REGEXES)
        if not matches:
            continue
        host_by_build[int(matches.group('build'))].append(instance['name'])
    logging.info("Running VMs by build: %r", host_by_build)
    return host_by_build


def _gc_compute_engine_instances(compute, project, zone, running):
    """Delete instances whose build number is not in ``running``.

    Returns the list of instance names for which deletion was requested
    (empty when the zone lists no instances).
    """
    # NOTE(review): only the first page of the list response is inspected --
    # confirm the projects never hold more instances than one page returns.
    instances = compute.instances().list(project=project, zone=zone).execute()
    if 'items' not in instances:
        return []
    host_by_build = _get_hosts_by_build(instances)
    stopped = []
    for build, names in host_by_build.items():
        if build in running:
            continue
        for name in names:
            stopped.append(name)
            logging.info("Stopping VM %s", name)
            compute.instances().delete(
                project=project, zone=zone, instance=name).execute()
    return stopped


def _gc_firewall_rules(compute, project, running):
    """Delete firewall rules whose build number is not in ``running``.

    Rules whose names match none of FIREWALL_REGEXES are left untouched.
    """
    firewalls = compute.firewalls().list(project=project).execute()
    if 'items' not in firewalls:
        return
    for firewall in firewalls['items']:
        matches = _matches_any_regex(firewall['name'], FIREWALL_REGEXES)
        if not matches:
            continue
        if int(matches.group('build')) in running:
            continue
        logging.info("Deleting firewall rule %s", firewall['name'])
        compute.firewalls().delete(
            project=project, firewall=firewall['name']).execute()