github.com/grumpyhome/grumpy@v0.3.1-0.20201208125205-7b775405bdf1/grumpy-tools-src/grumpy_tools/compiler/shard_test.py (about) 1 # Copyright 2016 Google Inc. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 15 """Wrapper for unit tests that loads a subset of all test methods.""" 16 17 from __future__ import unicode_literals 18 19 import argparse 20 import random 21 import re 22 import sys 23 import unittest 24 25 26 class _ShardAction(argparse.Action): 27 28 def __call__(self, parser, args, values, option_string=None): 29 match = re.match(r'(\d+)of(\d+)$', values) 30 if not match: 31 raise argparse.ArgumentError(self, 'bad shard spec: {}'.format(values)) 32 shard = int(match.group(1)) 33 count = int(match.group(2)) 34 if shard < 1 or count < 1 or shard > count: 35 raise argparse.ArgumentError(self, 'bad shard spec: {}'.format(values)) 36 setattr(args, self.dest, (shard, count)) 37 38 39 class _ShardTestLoader(unittest.TestLoader): 40 41 def __init__(self, shard, count): 42 super(_ShardTestLoader, self).__init__() 43 self.shard = shard 44 self.count = count 45 46 def getTestCaseNames(self, test_case_cls): 47 names = super(_ShardTestLoader, self).getTestCaseNames(test_case_cls) 48 state = random.getstate() 49 random.seed(self.count) 50 random.shuffle(names) 51 random.setstate(state) 52 n = len(names) 53 # self.shard is one-based. 54 return names[(self.shard - 1) * n / self.count:self.shard * n / self.count] 55 56 57 class _ShardTestRunner(object): 58 59 def run(self, test): 60 result = unittest.TestResult() 61 unittest.registerResult(result) 62 test(result) 63 for kind, errors in [('FAIL', result.failures), ('ERROR', result.errors)]: 64 for test, err in errors: 65 sys.stderr.write('{} {}\n{}'.format(test, kind, err)) 66 return result 67 68 69 def main(): 70 parser = argparse.ArgumentParser() 71 parser.add_argument('--shard', default=(1, 1), action=_ShardAction) 72 parser.add_argument('unittest_args', nargs='*') 73 args = parser.parse_args() 74 unittest.main(argv=[sys.argv[0]] + args.unittest_args, 75 testLoader=_ShardTestLoader(*args.shard), 76 testRunner=_ShardTestRunner)