github.com/grumpyhome/grumpy@v0.3.1-0.20201208125205-7b775405bdf1/grumpy-tools-src/grumpy_tools/compiler/shard_test.py (about)

     1  # Copyright 2016 Google Inc. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Wrapper for unit tests that loads a subset of all test methods."""
    16  
    17  from __future__ import unicode_literals
    18  
    19  import argparse
    20  import random
    21  import re
    22  import sys
    23  import unittest
    24  
    25  
    26  class _ShardAction(argparse.Action):
    27  
    28    def __call__(self, parser, args, values, option_string=None):
    29      match = re.match(r'(\d+)of(\d+)$', values)
    30      if not match:
    31        raise argparse.ArgumentError(self, 'bad shard spec: {}'.format(values))
    32      shard = int(match.group(1))
    33      count = int(match.group(2))
    34      if shard < 1 or count < 1 or shard > count:
    35        raise argparse.ArgumentError(self, 'bad shard spec: {}'.format(values))
    36      setattr(args, self.dest, (shard, count))
    37  
    38  
    39  class _ShardTestLoader(unittest.TestLoader):
    40  
    41    def __init__(self, shard, count):
    42      super(_ShardTestLoader, self).__init__()
    43      self.shard = shard
    44      self.count = count
    45  
    46    def getTestCaseNames(self, test_case_cls):
    47      names = super(_ShardTestLoader, self).getTestCaseNames(test_case_cls)
    48      state = random.getstate()
    49      random.seed(self.count)
    50      random.shuffle(names)
    51      random.setstate(state)
    52      n = len(names)
    53      # self.shard is one-based.
    54      return names[(self.shard - 1) * n / self.count:self.shard * n / self.count]
    55  
    56  
    57  class _ShardTestRunner(object):
    58  
    59    def run(self, test):
    60      result = unittest.TestResult()
    61      unittest.registerResult(result)
    62      test(result)
    63      for kind, errors in [('FAIL', result.failures), ('ERROR', result.errors)]:
    64        for test, err in errors:
    65          sys.stderr.write('{} {}\n{}'.format(test, kind, err))
    66      return result
    67  
    68  
    69  def main():
    70    parser = argparse.ArgumentParser()
    71    parser.add_argument('--shard', default=(1, 1), action=_ShardAction)
    72    parser.add_argument('unittest_args', nargs='*')
    73    args = parser.parse_args()
    74    unittest.main(argv=[sys.argv[0]] + args.unittest_args,
    75                  testLoader=_ShardTestLoader(*args.shard),
    76                  testRunner=_ShardTestRunner)