github.com/google/grumpy@v0.0.0-20171122020858-3ec87959189c/third_party/stdlib/unittest_loader.py (about)

     1  """Loading unittests."""
     2  
     3  import os
     4  import re
     5  import sys
     6  import traceback
     7  import types
     8  
     9  # from functools import cmp_to_key as _CmpToKey
    10  # from fnmatch import fnmatch
    11  import functools
    12  import fnmatch as _fnmatch
    13  _CmpToKey = functools.cmp_to_key
    14  fnmatch = _fnmatch.fnmatch
    15  
    16  # from . import case, suite
    17  import unittest_case as case
    18  import unittest_suite as suite
    19  
    20  __unittest = True
    21  
    22  # what about .pyc or .pyo (etc)
    23  # we would need to avoid loading the same tests multiple times
    24  # from '.py', '.pyc' *and* '.pyo'
    25  VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
    26  
    27  
    28  def _make_failed_import_test(name, suiteClass):
    29      message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
    30      return _make_failed_test('ModuleImportFailure', name, ImportError(message),
    31                               suiteClass)
    32  
    33  def _make_failed_load_tests(name, exception, suiteClass):
    34      return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)
    35  
    36  def _make_failed_test(classname, methodname, exception, suiteClass):
    37      def testFailure(self):
    38          raise exception
    39      attrs = {methodname: testFailure}
    40      TestClass = type(classname, (case.TestCase,), attrs)
    41      return suiteClass((TestClass(methodname),))
    42  
    43  
    44  class TestLoader(object):
    45      """
    46      This class is responsible for loading tests according to various criteria
    47      and returning them wrapped in a TestSuite
    48      """
    49      testMethodPrefix = 'test'
    50      sortTestMethodsUsing = cmp
    51      suiteClass = suite.TestSuite
    52      _top_level_dir = None
    53  
    54      def loadTestsFromTestCase(self, testCaseClass):
    55          """Return a suite of all tests cases contained in testCaseClass"""
    56          if issubclass(testCaseClass, suite.TestSuite):
    57              raise TypeError("Test cases should not be derived from TestSuite." \
    58                                  " Maybe you meant to derive from TestCase?")
    59          testCaseNames = self.getTestCaseNames(testCaseClass)
    60          if not testCaseNames and hasattr(testCaseClass, 'runTest'):
    61              testCaseNames = ['runTest']
    62          loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
    63          return loaded_suite
    64  
    65      def loadTestsFromModule(self, module, use_load_tests=True):
    66          """Return a suite of all tests cases contained in the given module"""
    67          tests = []
    68          for name in dir(module):
    69              obj = getattr(module, name)
    70              if isinstance(obj, type) and issubclass(obj, case.TestCase):
    71                  tests.append(self.loadTestsFromTestCase(obj))
    72  
    73          load_tests = getattr(module, 'load_tests', None)
    74          tests = self.suiteClass(tests)
    75          if use_load_tests and load_tests is not None:
    76              try:
    77                  return load_tests(self, tests, None)
    78              except Exception, e:
    79                  return _make_failed_load_tests(module.__name__, e,
    80                                                 self.suiteClass)
    81          return tests
    82  
    83      def loadTestsFromName(self, name, module=None):
    84          """Return a suite of all tests cases given a string specifier.
    85  
    86          The name may resolve either to a module, a test case class, a
    87          test method within a test case class, or a callable object which
    88          returns a TestCase or TestSuite instance.
    89  
    90          The method optionally resolves the names relative to a given module.
    91          """
    92          parts = name.split('.')
    93          if module is None:
    94              parts_copy = parts[:]
    95              while parts_copy:
    96                  try:
    97                      module = __import__('.'.join(parts_copy))
    98                      break
    99                  except ImportError:
   100                      del parts_copy[-1]
   101                      if not parts_copy:
   102                          raise
   103              parts = parts[1:]
   104          obj = module
   105          for part in parts:
   106              parent, obj = obj, getattr(obj, part)
   107  
   108          if isinstance(obj, types.ModuleType):
   109              return self.loadTestsFromModule(obj)
   110          elif isinstance(obj, type) and issubclass(obj, case.TestCase):
   111              return self.loadTestsFromTestCase(obj)
   112          elif (isinstance(obj, types.UnboundMethodType) and
   113                isinstance(parent, type) and
   114                issubclass(parent, case.TestCase)):
   115              name = parts[-1]
   116              inst = parent(name)
   117              return self.suiteClass([inst])
   118          elif isinstance(obj, suite.TestSuite):
   119              return obj
   120          elif hasattr(obj, '__call__'):
   121              test = obj()
   122              if isinstance(test, suite.TestSuite):
   123                  return test
   124              elif isinstance(test, case.TestCase):
   125                  return self.suiteClass([test])
   126              else:
   127                  raise TypeError("calling %s returned %s, not a test" %
   128                                  (obj, test))
   129          else:
   130              raise TypeError("don't know how to make test from: %s" % obj)
   131  
   132      def loadTestsFromNames(self, names, module=None):
   133          """Return a suite of all tests cases found using the given sequence
   134          of string specifiers. See 'loadTestsFromName()'.
   135          """
   136          suites = [self.loadTestsFromName(name, module) for name in names]
   137          return self.suiteClass(suites)
   138  
   139      def getTestCaseNames(self, testCaseClass):
   140          """Return a sorted sequence of method names found within testCaseClass
   141          """
   142          def isTestMethod(attrname, testCaseClass=testCaseClass,
   143                           prefix=self.testMethodPrefix):
   144              return attrname.startswith(prefix) and \
   145                  hasattr(getattr(testCaseClass, attrname), '__call__')
   146          # testFnNames = filter(isTestMethod, dir(testCaseClass))
   147          testFnNames = [x for x in dir(testCaseClass) if isTestMethod(x)]
   148          if self.sortTestMethodsUsing:
   149              testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
   150          return testFnNames
   151  
   152      def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
   153          """Find and return all test modules from the specified start
   154          directory, recursing into subdirectories to find them. Only test files
   155          that match the pattern will be loaded. (Using shell style pattern
   156          matching.)
   157  
   158          All test modules must be importable from the top level of the project.
   159          If the start directory is not the top level directory then the top
   160          level directory must be specified separately.
   161  
   162          If a test package name (directory with '__init__.py') matches the
   163          pattern then the package will be checked for a 'load_tests' function. If
   164          this exists then it will be called with loader, tests, pattern.
   165  
   166          If load_tests exists then discovery does  *not* recurse into the package,
   167          load_tests is responsible for loading all tests in the package.
   168  
   169          The pattern is deliberately not stored as a loader attribute so that
   170          packages can continue discovery themselves. top_level_dir is stored so
   171          load_tests does not need to pass this argument in to loader.discover().
   172          """
   173          set_implicit_top = False
   174          if top_level_dir is None and self._top_level_dir is not None:
   175              # make top_level_dir optional if called from load_tests in a package
   176              top_level_dir = self._top_level_dir
   177          elif top_level_dir is None:
   178              set_implicit_top = True
   179              top_level_dir = start_dir
   180  
   181          top_level_dir = os.path.abspath(top_level_dir)
   182  
   183          if not top_level_dir in sys.path:
   184              # all test modules must be importable from the top level directory
   185              # should we *unconditionally* put the start directory in first
   186              # in sys.path to minimise likelihood of conflicts between installed
   187              # modules and development versions?
   188              sys.path.insert(0, top_level_dir)
   189          self._top_level_dir = top_level_dir
   190  
   191          is_not_importable = False
   192          if os.path.isdir(os.path.abspath(start_dir)):
   193              start_dir = os.path.abspath(start_dir)
   194              if start_dir != top_level_dir:
   195                  is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
   196          else:
   197              # support for discovery from dotted module names
   198              try:
   199                  __import__(start_dir)
   200              except ImportError:
   201                  is_not_importable = True
   202              else:
   203                  the_module = sys.modules[start_dir]
   204                  top_part = start_dir.split('.')[0]
   205                  start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
   206                  if set_implicit_top:
   207                      self._top_level_dir = self._get_directory_containing_module(top_part)
   208                      sys.path.remove(top_level_dir)
   209  
   210          if is_not_importable:
   211              raise ImportError('Start directory is not importable: %r' % start_dir)
   212  
   213          tests = list(self._find_tests(start_dir, pattern))
   214          return self.suiteClass(tests)
   215  
   216      def _get_directory_containing_module(self, module_name):
   217          module = sys.modules[module_name]
   218          full_path = os.path.abspath(module.__file__)
   219  
   220          if os.path.basename(full_path).lower().startswith('__init__.py'):
   221              return os.path.dirname(os.path.dirname(full_path))
   222          else:
   223              # here we have been given a module rather than a package - so
   224              # all we can do is search the *same* directory the module is in
   225              # should an exception be raised instead
   226              return os.path.dirname(full_path)
   227  
   228      def _get_name_from_path(self, path):
   229          path = os.path.splitext(os.path.normpath(path))[0]
   230  
   231          _relpath = os.path.relpath(path, self._top_level_dir)
   232          assert not os.path.isabs(_relpath), "Path must be within the project"
   233          assert not _relpath.startswith('..'), "Path must be within the project"
   234  
   235          name = _relpath.replace(os.path.sep, '.')
   236          return name
   237  
   238      def _get_module_from_name(self, name):
   239          __import__(name)
   240          return sys.modules[name]
   241  
   242      def _match_path(self, path, full_path, pattern):
   243          # override this method to use alternative matching strategy
   244          return fnmatch(path, pattern)
   245  
   246      def _find_tests(self, start_dir, pattern):
   247          """Used by discovery. Yields test suites it loads."""
   248          paths = os.listdir(start_dir)
   249  
   250          for path in paths:
   251              full_path = os.path.join(start_dir, path)
   252              if os.path.isfile(full_path):
   253                  if not VALID_MODULE_NAME.match(path):
   254                      # valid Python identifiers only
   255                      continue
   256                  if not self._match_path(path, full_path, pattern):
   257                      continue
   258                  # if the test file matches, load it
   259                  name = self._get_name_from_path(full_path)
   260                  try:
   261                      module = self._get_module_from_name(name)
   262                  except:
   263                      yield _make_failed_import_test(name, self.suiteClass)
   264                  else:
   265                      mod_file = os.path.abspath(getattr(module, '__file__', full_path))
   266                      realpath = os.path.splitext(os.path.realpath(mod_file))[0]
   267                      fullpath_noext = os.path.splitext(os.path.realpath(full_path))[0]
   268                      if realpath.lower() != fullpath_noext.lower():
   269                          module_dir = os.path.dirname(realpath)
   270                          mod_name = os.path.splitext(os.path.basename(full_path))[0]
   271                          expected_dir = os.path.dirname(full_path)
   272                          msg = ("%r module incorrectly imported from %r. Expected %r. "
   273                                 "Is this module globally installed?")
   274                          raise ImportError(msg % (mod_name, module_dir, expected_dir))
   275                      yield self.loadTestsFromModule(module)
   276              elif os.path.isdir(full_path):
   277                  if not os.path.isfile(os.path.join(full_path, '__init__.py')):
   278                      continue
   279  
   280                  load_tests = None
   281                  tests = None
   282                  if fnmatch(path, pattern):
   283                      # only check load_tests if the package directory itself matches the filter
   284                      name = self._get_name_from_path(full_path)
   285                      package = self._get_module_from_name(name)
   286                      load_tests = getattr(package, 'load_tests', None)
   287                      tests = self.loadTestsFromModule(package, use_load_tests=False)
   288  
   289                  if load_tests is None:
   290                      if tests is not None:
   291                          # tests loaded from package file
   292                          yield tests
   293                      # recurse into the package
   294                      for test in self._find_tests(full_path, pattern):
   295                          yield test
   296                  else:
   297                      try:
   298                          yield load_tests(self, tests, pattern)
   299                      except Exception, e:
   300                          yield _make_failed_load_tests(package.__name__, e,
   301                                                        self.suiteClass)
   302  
   303  defaultTestLoader = TestLoader()
   304  
   305  
   306  def _makeLoader(prefix, sortUsing, suiteClass=None):
   307      loader = TestLoader()
   308      loader.sortTestMethodsUsing = sortUsing
   309      loader.testMethodPrefix = prefix
   310      if suiteClass:
   311          loader.suiteClass = suiteClass
   312      return loader
   313  
   314  def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
   315      return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
   316  
   317  def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
   318                suiteClass=suite.TestSuite):
   319      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
   320  
   321  def findTestCases(module, prefix='test', sortUsing=cmp,
   322                    suiteClass=suite.TestSuite):
   323      return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)