github.com/grumpyhome/grumpy@v0.3.1-0.20201208125205-7b775405bdf1/grumpy-tools-src/grumpy_tools/compiler/block_test.py (about)

     1  # coding=utf-8
     2  
     3  # Copyright 2016 Google Inc. All Rights Reserved.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  """Tests Package, Block, BlockVisitor and related classes."""
    18  
    19  from __future__ import unicode_literals
    20  
    21  import textwrap
    22  import unittest
    23  
    24  from grumpy_tools.compiler import block
    25  from grumpy_tools.compiler import imputil
    26  from grumpy_tools.compiler import util
    27  import pythonparser
    28  
    29  
    30  class PackageTest(unittest.TestCase):
    31  
    32    def testCreate(self):
    33      package = block.Package('foo/bar/baz')
    34      self.assertEqual(package.name, 'foo/bar/baz')
    35      self.assertEqual(package.alias, 'π_fooΓbarΓbaz')
    36  
    37    def testCreateGrump(self):
    38      package = block.Package('foo/bar/baz', 'myalias')
    39      self.assertEqual(package.name, 'foo/bar/baz')
    40      self.assertEqual(package.alias, 'myalias')
    41  
    42  
    43  class BlockTest(unittest.TestCase):
    44  
    45    def testLoop(self):
    46      b = _MakeModuleBlock()
    47      loop = b.push_loop(None)
    48      self.assertEqual(loop, b.top_loop())
    49      inner_loop = b.push_loop(None)
    50      self.assertEqual(inner_loop, b.top_loop())
    51      b.pop_loop()
    52      self.assertEqual(loop, b.top_loop())
    53  
    54    def testResolveName(self):
    55      module_block = _MakeModuleBlock()
    56      block_vars = {'foo': block.Var('foo', block.Var.TYPE_LOCAL)}
    57      func1_block = block.FunctionBlock(module_block, 'func1', block_vars, False)
    58      block_vars = {'bar': block.Var('bar', block.Var.TYPE_LOCAL)}
    59      func2_block = block.FunctionBlock(func1_block, 'func2', block_vars, False)
    60      block_vars = {'case': block.Var('case', block.Var.TYPE_LOCAL)}
    61      keyword_block = block.FunctionBlock(
    62          module_block, 'keyword_func', block_vars, False)
    63      class1_block = block.ClassBlock(module_block, 'Class1', set())
    64      class2_block = block.ClassBlock(func1_block, 'Class2', set())
    65      self.assertRegexpMatches(self._ResolveName(module_block, 'foo'),
    66                               r'ResolveGlobal\b.*foo')
    67      self.assertRegexpMatches(self._ResolveName(module_block, 'bar'),
    68                               r'ResolveGlobal\b.*bar')
    69      self.assertRegexpMatches(self._ResolveName(module_block, 'baz'),
    70                               r'ResolveGlobal\b.*baz')
    71      self.assertRegexpMatches(self._ResolveName(func1_block, 'foo'),
    72                               r'CheckLocal\b.*foo')
    73      self.assertRegexpMatches(self._ResolveName(func1_block, 'bar'),
    74                               r'ResolveGlobal\b.*bar')
    75      self.assertRegexpMatches(self._ResolveName(func1_block, 'baz'),
    76                               r'ResolveGlobal\b.*baz')
    77      self.assertRegexpMatches(self._ResolveName(func2_block, 'foo'),
    78                               r'CheckLocal\b.*foo')
    79      self.assertRegexpMatches(self._ResolveName(func2_block, 'bar'),
    80                               r'CheckLocal\b.*bar')
    81      self.assertRegexpMatches(self._ResolveName(func2_block, 'baz'),
    82                               r'ResolveGlobal\b.*baz')
    83      self.assertRegexpMatches(self._ResolveName(class1_block, 'foo'),
    84                               r'ResolveClass\(.*, nil, .*foo')
    85      self.assertRegexpMatches(self._ResolveName(class2_block, 'foo'),
    86                               r'ResolveClass\(.*, µfoo, .*foo')
    87      self.assertRegexpMatches(self._ResolveName(keyword_block, 'case'),
    88                               r'CheckLocal\b.*µcase, "case"')
    89  
    90    def _ResolveName(self, b, name):
    91      writer = util.Writer()
    92      b.resolve_name(writer, name)
    93      return writer.getvalue()
    94  
    95  
    96  class BlockVisitorTest(unittest.TestCase):
    97  
    98    def testAssignSingle(self):
    99      visitor = block.BlockVisitor()
   100      visitor.visit(_ParseStmt('foo = 3'))
   101      self.assertEqual(visitor.vars.keys(), ['foo'])
   102      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   103  
   104    def testAssignMultiple(self):
   105      visitor = block.BlockVisitor()
   106      visitor.visit(_ParseStmt('foo = bar = 123'))
   107      self.assertEqual(sorted(visitor.vars.keys()), ['bar', 'foo'])
   108      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   109      self.assertRegexpMatches(visitor.vars['bar'].init_expr, r'UnboundLocal')
   110  
   111    def testAssignTuple(self):
   112      visitor = block.BlockVisitor()
   113      visitor.visit(_ParseStmt('foo, bar = "a", "b"'))
   114      self.assertEqual(sorted(visitor.vars.keys()), ['bar', 'foo'])
   115      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   116      self.assertRegexpMatches(visitor.vars['bar'].init_expr, r'UnboundLocal')
   117  
   118    def testAssignNested(self):
   119      visitor = block.BlockVisitor()
   120      visitor.visit(_ParseStmt('foo, (bar, baz) = "a", ("b", "c")'))
   121      self.assertEqual(sorted(visitor.vars.keys()), ['bar', 'baz', 'foo'])
   122      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   123      self.assertRegexpMatches(visitor.vars['bar'].init_expr, r'UnboundLocal')
   124      self.assertRegexpMatches(visitor.vars['baz'].init_expr, r'UnboundLocal')
   125  
   126    def testAugAssignSingle(self):
   127      visitor = block.BlockVisitor()
   128      visitor.visit(_ParseStmt('foo += 3'))
   129      self.assertEqual(visitor.vars.keys(), ['foo'])
   130      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   131  
   132    def testVisitClassDef(self):
   133      visitor = block.BlockVisitor()
   134      visitor.visit(_ParseStmt('class Foo(object): pass'))
   135      self.assertEqual(visitor.vars.keys(), ['Foo'])
   136      self.assertRegexpMatches(visitor.vars['Foo'].init_expr, r'UnboundLocal')
   137  
   138    def testExceptHandler(self):
   139      visitor = block.BlockVisitor()
   140      visitor.visit(_ParseStmt(textwrap.dedent("""\
   141          try:
   142            pass
   143          except Exception as foo:
   144            pass
   145          except TypeError as bar:
   146            pass""")))
   147      self.assertEqual(sorted(visitor.vars.keys()), ['bar', 'foo'])
   148      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   149      self.assertRegexpMatches(visitor.vars['bar'].init_expr, r'UnboundLocal')
   150  
   151    def testFor(self):
   152      visitor = block.BlockVisitor()
   153      visitor.visit(_ParseStmt('for i in foo: pass'))
   154      self.assertEqual(visitor.vars.keys(), ['i'])
   155      self.assertRegexpMatches(visitor.vars['i'].init_expr, r'UnboundLocal')
   156  
   157    def testFunctionDef(self):
   158      visitor = block.BlockVisitor()
   159      visitor.visit(_ParseStmt('def foo(): pass'))
   160      self.assertEqual(visitor.vars.keys(), ['foo'])
   161      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   162  
   163    def testImport(self):
   164      visitor = block.BlockVisitor()
   165      visitor.visit(_ParseStmt('import foo.bar, baz'))
   166      self.assertEqual(sorted(visitor.vars.keys()), ['baz', 'foo'])
   167      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   168      self.assertRegexpMatches(visitor.vars['baz'].init_expr, r'UnboundLocal')
   169  
   170    def testImportFrom(self):
   171      visitor = block.BlockVisitor()
   172      visitor.visit(_ParseStmt('from foo.bar import baz, qux'))
   173      self.assertEqual(sorted(visitor.vars.keys()), ['baz', 'qux'])
   174      self.assertRegexpMatches(visitor.vars['baz'].init_expr, r'UnboundLocal')
   175      self.assertRegexpMatches(visitor.vars['qux'].init_expr, r'UnboundLocal')
   176  
   177    def testGlobal(self):
   178      visitor = block.BlockVisitor()
   179      visitor.visit(_ParseStmt('global foo, bar'))
   180      self.assertEqual(sorted(visitor.vars.keys()), ['bar', 'foo'])
   181      self.assertIsNone(visitor.vars['foo'].init_expr)
   182      self.assertIsNone(visitor.vars['bar'].init_expr)
   183  
   184    def testGlobalIsParam(self):
   185      visitor = block.BlockVisitor()
   186      visitor.vars['foo'] = block.Var('foo', block.Var.TYPE_PARAM, arg_index=0)
   187      self.assertRaisesRegexp(util.ParseError, 'is parameter and global',
   188                              visitor.visit, _ParseStmt('global foo'))
   189  
   190    def testGlobalUsedPriorToDeclaration(self):
   191      node = pythonparser.parse('foo = 42\nglobal foo')
   192      visitor = block.BlockVisitor()
   193      self.assertRaisesRegexp(util.ParseError, 'used prior to global declaration',
   194                              visitor.generic_visit, node)
   195  
   196  
   197  class FunctionBlockVisitorTest(unittest.TestCase):
   198  
   199    def testArgs(self):
   200      func = _ParseStmt('def foo(bar, baz, *args, **kwargs): pass')
   201      visitor = block.FunctionBlockVisitor(func)
   202      self.assertIn('bar', visitor.vars)
   203      self.assertIn('baz', visitor.vars)
   204      self.assertIn('args', visitor.vars)
   205      self.assertIn('kwargs', visitor.vars)
   206      self.assertRegexpMatches(visitor.vars['bar'].init_expr, r'Args\[0\]')
   207      self.assertRegexpMatches(visitor.vars['baz'].init_expr, r'Args\[1\]')
   208      self.assertRegexpMatches(visitor.vars['args'].init_expr, r'Args\[2\]')
   209      self.assertRegexpMatches(visitor.vars['kwargs'].init_expr, r'Args\[3\]')
   210  
   211    def testArgsDuplicate(self):
   212      func = _ParseStmt('def foo(bar, baz, bar=None): pass')
   213      self.assertRaisesRegexp(util.ParseError, 'duplicate argument',
   214                              block.FunctionBlockVisitor, func)
   215  
   216    def testYield(self):
   217      visitor = block.FunctionBlockVisitor(_ParseStmt('def foo(): pass'))
   218      visitor.visit(_ParseStmt('yield "foo"'))
   219      self.assertTrue(visitor.is_generator)
   220  
   221    def testYieldExpr(self):
   222      visitor = block.FunctionBlockVisitor(_ParseStmt('def foo(): pass'))
   223      visitor.visit(_ParseStmt('foo = (yield)'))
   224      self.assertTrue(visitor.is_generator)
   225      self.assertEqual(sorted(visitor.vars.keys()), ['foo'])
   226      self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
   227  
   228    def testTupleArgs(self):
   229      func = _ParseStmt('def foo((bar, baz)): pass')
   230      visitor = block.FunctionBlockVisitor(func)
   231      self.assertEqual(len(visitor.vars), 3)
   232      self.assertEqual(len([v for v in visitor.vars if visitor.vars[v].type == block.Var.TYPE_TUPLE_PARAM]), 2)
   233      self.assertIn('bar', visitor.vars)
   234      self.assertIn('baz', visitor.vars)
   235  
   236  def _MakeModuleBlock():
   237    importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False)
   238    return block.ModuleBlock(importer, '__main__', '<test>', '',
   239                             imputil.FutureFeatures())
   240  
   241  
   242  def _ParseStmt(stmt_str):
   243    return pythonparser.parse(stmt_str).body[0]
   244  
   245  
   246  if __name__ == '__main__':
   247    unittest.main()