github.com/grumpyhome/grumpy@v0.3.1-0.20201208125205-7b775405bdf1/grumpy-tools-src/grumpy_tools/compiler/expr_visitor.py (about)

     1  # coding=utf-8
     2  
     3  # Copyright 2016 Google Inc. All Rights Reserved.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  """Visitor class for traversing Python expressions."""
    18  
    19  from __future__ import unicode_literals
    20  
    21  import contextlib
    22  import textwrap
    23  
    24  from grumpy_tools.compiler import expr
    25  from grumpy_tools.compiler import util
    26  from pythonparser import algorithm
    27  from pythonparser import ast
    28  
    29  try:
    30    long           # Python 2
    31  except NameError:
    32    long = int     # Python 3
    33  
    34  try:
    35    unicode        # Python 2
    36  except NameError:
    37    unicode = str  # Python 3
    38  
    39  
    40  class ExprVisitor(algorithm.Visitor):
    41    """Builds and returns a Go expression representing the Python nodes."""
    42  
    43    # pylint: disable=invalid-name,missing-docstring
    44  
    45    def __init__(self, stmt_visitor):
    46      self.stmt_visitor = stmt_visitor
    47      self.block = stmt_visitor.block
    48      self.writer = stmt_visitor.writer
    49  
    50    def generic_visit(self, node):
    51      msg = 'expression node not yet implemented: ' + type(node).__name__
    52      raise util.ParseError(node, msg)
    53  
    54    def visit_Attribute(self, node):
    55      with self.visit(node.value) as obj:
    56        attr = self.block.alloc_temp()
    57        self.writer.write_checked_call2(
    58            attr, 'πg.GetAttr(πF, {}, {}, nil)',
    59            obj.expr, self.block.root.intern(node.attr))
    60      return attr
    61  
    62    def visit_BinOp(self, node):
    63      result = self.block.alloc_temp()
    64      with self.visit(node.left) as lhs, self.visit(node.right) as rhs:
    65        op_type = type(node.op)
    66        if op_type in ExprVisitor._BIN_OP_TEMPLATES:
    67          tmpl = ExprVisitor._BIN_OP_TEMPLATES[op_type]
    68          self.writer.write_checked_call2(
    69              result, tmpl, lhs=lhs.expr, rhs=rhs.expr)
    70        else:
    71          msg = 'binary op not implemented: {}'.format(op_type.__name__)
    72          raise util.ParseError(node, msg)
    73      return result
    74  
    75    def visit_BoolOp(self, node):
    76      result = self.block.alloc_temp()
    77      with self.block.alloc_temp('bool') as is_true:
    78        if isinstance(node.op, ast.And):
    79          cond_expr = '!' + is_true.expr
    80        else:
    81          cond_expr = is_true.expr
    82        end_label = self.block.genlabel()
    83        num_values = len(node.values)
    84        for i, n in enumerate(node.values):
    85          with self.visit(n) as v:
    86            self.writer.write('{} = {}'.format(result.expr, v.expr))
    87          if i < num_values - 1:
    88            self.writer.write_checked_call2(
    89                is_true, 'πg.IsTrue(πF, {})', result.expr)
    90            self.writer.write_tmpl(textwrap.dedent("""\
    91                if $cond_expr {
    92                \tgoto Label$end_label
    93                }"""), cond_expr=cond_expr, end_label=end_label)
    94      self.writer.write_label(end_label)
    95      return result
    96  
    97    def visit_Call(self, node):
    98      # Build positional arguments.
    99      args = expr.nil_expr
   100      if node.args:
   101        args = self.block.alloc_temp('[]*πg.Object')
   102        self.writer.write('{} = πF.MakeArgs({})'.format(args.expr,
   103                                                        len(node.args)))
   104        for i, n in enumerate(node.args):
   105          with self.visit(n) as a:
   106            self.writer.write('{}[{}] = {}'.format(args.expr, i, a.expr))
   107      varg = expr.nil_expr
   108      if node.starargs:
   109        varg = self.visit(node.starargs)
   110      # Build keyword arguments
   111      keywords = expr.nil_expr
   112      if node.keywords:
   113        values = []
   114        for k in node.keywords:
   115          values.append((util.go_str(k.arg), self.visit(k.value)))
   116        keywords = self.block.alloc_temp('πg.KWArgs')
   117        self.writer.write_tmpl('$keywords = πg.KWArgs{', keywords=keywords.name)
   118        with self.writer.indent_block():
   119          for k, v in values:
   120            with v:
   121              self.writer.write_tmpl('{$name, $value},', name=k, value=v.expr)
   122        self.writer.write('}')
   123      kwargs = expr.nil_expr
   124      if node.kwargs:
   125        kwargs = self.visit(node.kwargs)
   126      # Invoke function with all parameters.
   127      with args, varg, keywords, kwargs, self.visit(node.func) as func:
   128        result = self.block.alloc_temp()
   129        if varg is expr.nil_expr and kwargs is expr.nil_expr:
   130          self.writer.write_checked_call2(result, '{}.Call(πF, {}, {})',
   131                                          func.expr, args.expr, keywords.expr)
   132        else:
   133          self.writer.write_checked_call2(result,
   134                                          'πg.Invoke(πF, {}, {}, {}, {}, {})',
   135                                          func.expr, args.expr, varg.expr,
   136                                          keywords.expr, kwargs.expr)
   137        if node.args:
   138          self.writer.write('πF.FreeArgs({})'.format(args.expr))
   139      return result
   140  
   141    def visit_Compare(self, node):
   142      result = self.block.alloc_temp()
   143      lhs = self.visit(node.left)
   144      n = len(node.ops)
   145      end_label = self.block.genlabel() if n > 1 else None
   146      for i, (op, comp) in enumerate(zip(node.ops, node.comparators)):
   147        rhs = self.visit(comp)
   148        op_type = type(op)
   149        if op_type in ExprVisitor._CMP_OP_TEMPLATES:
   150          tmpl = ExprVisitor._CMP_OP_TEMPLATES[op_type]
   151          self.writer.write_checked_call2(
   152              result, tmpl, lhs=lhs.expr, rhs=rhs.expr)
   153        elif isinstance(op, (ast.In, ast.NotIn)):
   154          with self.block.alloc_temp('bool') as contains:
   155            self.writer.write_checked_call2(
   156                contains, 'πg.Contains(πF, {}, {})', rhs.expr, lhs.expr)
   157            invert = '' if isinstance(op, ast.In) else '!'
   158            self.writer.write('{} = πg.GetBool({}{}).ToObject()'.format(
   159                result.name, invert, contains.expr))
   160        elif isinstance(op, ast.Is):
   161          self.writer.write('{} = πg.GetBool({} == {}).ToObject()'.format(
   162              result.name, lhs.expr, rhs.expr))
   163        elif isinstance(op, ast.IsNot):
   164          self.writer.write('{} = πg.GetBool({} != {}).ToObject()'.format(
   165              result.name, lhs.expr, rhs.expr))
   166        else:
   167          raise AssertionError('unrecognized compare op: {}'.format(
   168              op_type.__name__))
   169        if i < n - 1:
   170          with self.block.alloc_temp('bool') as cond:
   171            self.writer.write_checked_call2(
   172                cond, 'πg.IsTrue(πF, {})', result.expr)
   173            self.writer.write_tmpl(textwrap.dedent("""\
   174                if !$cond {
   175                \tgoto Label$end_label
   176                }"""), cond=cond.expr, end_label=end_label)
   177        lhs.free()
   178        lhs = rhs
   179      rhs.free()
   180      if end_label is not None:
   181        self.writer.write_label(end_label)
   182      return result
   183  
   184    def visit_Dict(self, node):
   185      with self.block.alloc_temp('*πg.Dict') as d:
   186        self.writer.write('{} = πg.NewDict()'.format(d.name))
   187        for k, v in zip(node.keys, node.values):
   188          with self.visit(k) as key, self.visit(v) as value:
   189            self.writer.write_checked_call1('{}.SetItem(πF, {}, {})',
   190                                            d.expr, key.expr, value.expr)
   191        result = self.block.alloc_temp()
   192        self.writer.write('{} = {}.ToObject()'.format(result.name, d.expr))
   193      return result
   194  
   195    def visit_Set(self, node):
   196      with self.block.alloc_temp('*πg.Set') as s:
   197        self.writer.write('{} = πg.NewSet()'.format(s.name))
   198        for e in node.elts:
   199          with self.visit(e) as value:
   200            self.writer.write_checked_call2(expr.blank_var, '{}.Add(πF, {})',
   201                                            s.expr, value.expr)
   202        result = self.block.alloc_temp()
   203        self.writer.write('{} = {}.ToObject()'.format(result.name, s.expr))
   204      return result
   205  
   206    def visit_DictComp(self, node):
   207      result = self.block.alloc_temp()
   208      elt = ast.Tuple(elts=[node.key, node.value])
   209      gen_node = ast.GeneratorExp(
   210          elt=elt, generators=node.generators, loc=node.loc)
   211      with self.visit(gen_node) as gen:
   212        self.writer.write_checked_call2(
   213            result, 'πg.DictType.Call(πF, πg.Args{{{}}}, nil)', gen.expr)
   214      return result
   215  
   216    def visit_ExtSlice(self, node):
   217      result = self.block.alloc_temp()
   218      if len(node.dims) <= util.MAX_DIRECT_TUPLE:
   219        with contextlib.nested(*(self.visit(d) for d in node.dims)) as dims:
   220          self.writer.write('{} = πg.NewTuple{}({}).ToObject()'.format(
   221              result.name, len(dims), ', '.join(d.expr for d in dims)))
   222      else:
   223        with self.block.alloc_temp('[]*πg.Object') as dims:
   224          self.writer.write('{} = make([]*πg.Object, {})'.format(
   225              dims.name, len(node.dims)))
   226          for i, dim in enumerate(node.dims):
   227            with self.visit(dim) as s:
   228              self.writer.write('{}[{}] = {}'.format(dims.name, i, s.expr))
   229          self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format(
   230              result.name, dims.expr))
   231      return result
   232  
   233    def visit_GeneratorExp(self, node):
   234      body = ast.Expr(value=ast.Yield(value=node.elt), loc=node.loc)
   235      for comp_node in reversed(node.generators):
   236        for if_node in reversed(comp_node.ifs):
   237          body = ast.If(test=if_node, body=[body], orelse=[], loc=node.loc)  # pylint: disable=redefined-variable-type
   238        body = ast.For(target=comp_node.target, iter=comp_node.iter,
   239                       body=[body], orelse=[], loc=node.loc)
   240  
   241      args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[])
   242      node = ast.FunctionDef(name='<generator>', args=args, body=[body])
   243      gen_func = self.stmt_visitor.visit_function_inline(node)
   244      result = self.block.alloc_temp()
   245      self.writer.write_checked_call2(
   246          result, '{}.Call(πF, nil, nil)', gen_func.expr)
   247      return result
   248  
   249    def visit_IfExp(self, node):
   250      else_label, end_label = self.block.genlabel(), self.block.genlabel()
   251      result = self.block.alloc_temp()
   252      with self.visit(node.test) as test, self.block.alloc_temp('bool') as cond:
   253        self.writer.write_checked_call2(
   254            cond, 'πg.IsTrue(πF, {})', test.expr)
   255        self.writer.write_tmpl(textwrap.dedent("""\
   256            if !$cond {
   257            \tgoto Label$else_label
   258            }"""), cond=cond.expr, else_label=else_label)
   259      with self.visit(node.body) as value:
   260        self.writer.write('{} = {}'.format(result.name, value.expr))
   261        self.writer.write('goto Label{}'.format(end_label))
   262      self.writer.write_label(else_label)
   263      with self.visit(node.orelse) as value:
   264        self.writer.write('{} = {}'.format(result.name, value.expr))
   265      self.writer.write_label(end_label)
   266      return result
   267  
   268    def visit_Index(self, node):
   269      result = self.block.alloc_temp()
   270      with self.visit(node.value) as v:
   271        self.writer.write('{} = {}'.format(result.name, v.expr))
   272      return result
   273  
   274    def visit_Lambda(self, node):
   275      ret = ast.Return(value=node.body, loc=node.loc)
   276      func_node = ast.FunctionDef(
   277          name='<lambda>', args=node.args, body=[ret])
   278      return self.stmt_visitor.visit_function_inline(func_node)
   279  
   280    def visit_List(self, node):
   281      with self._visit_seq_elts(node.elts) as elems:
   282        result = self.block.alloc_temp()
   283        self.writer.write('{} = πg.NewList({}...).ToObject()'.format(
   284            result.expr, elems.expr))
   285      return result
   286  
   287    def visit_ListComp(self, node):
   288      result = self.block.alloc_temp()
   289      gen_node = ast.GeneratorExp(
   290          elt=node.elt, generators=node.generators, loc=node.loc)
   291      with self.visit(gen_node) as gen:
   292        self.writer.write_checked_call2(
   293            result, 'πg.ListType.Call(πF, πg.Args{{{}}}, nil)', gen.expr)
   294      return result
   295  
   296    def visit_Name(self, node):
   297      return self.block.resolve_name(self.writer, node.id)
   298  
   299    def visit_Num(self, node):
   300      if isinstance(node.n, int):
   301        expr_str = 'NewInt({})'.format(node.n)
   302      elif isinstance(node.n, long):
   303        a = abs(node.n)
   304        gobytes = ''
   305        while a:
   306          gobytes = hex(int(a&255)) + ',' + gobytes
   307          a >>= 8
   308        expr_str = 'NewLongFromBytes([]byte{{{}}})'.format(gobytes)
   309        if node.n < 0:
   310          expr_str = expr_str + '.Neg()'
   311      elif isinstance(node.n, float):
   312        expr_str = 'NewFloat({})'.format(node.n)
   313      elif isinstance(node.n, complex):
   314        expr_str = 'NewComplex(complex({}, {}))'.format(node.n.real, node.n.imag)
   315      else:
   316        msg = 'number type not yet implemented: ' + type(node.n).__name__
   317        raise util.ParseError(node, msg)
   318      return expr.GeneratedLiteral('πg.' + expr_str + '.ToObject()')
   319  
   320    def visit_Slice(self, node):
   321      result = self.block.alloc_temp()
   322      lower = upper = step = expr.GeneratedLiteral('πg.None')
   323      if node.lower:
   324        lower = self.visit(node.lower)
   325      if node.upper:
   326        upper = self.visit(node.upper)
   327      if node.step:
   328        step = self.visit(node.step)
   329      with lower, upper, step:
   330        self.writer.write_checked_call2(
   331            result, 'πg.SliceType.Call(πF, πg.Args{{{}, {}, {}}}, nil)',
   332            lower.expr, upper.expr, step.expr)
   333      return result
   334  
   335    def visit_Subscript(self, node):
   336      rhs = self.visit(node.slice)
   337      result = self.block.alloc_temp()
   338      with rhs, self.visit(node.value) as lhs:
   339        self.writer.write_checked_call2(result, 'πg.GetItem(πF, {}, {})',
   340                                        lhs.expr, rhs.expr)
   341      return result
   342  
   343    def visit_Str(self, node):
   344      if isinstance(node.s, unicode):
   345        expr_str = 'πg.NewUnicode({}).ToObject()'.format(
   346            util.go_str(node.s.encode('utf-8')))
   347      else:
   348        expr_str = '{}.ToObject()'.format(self.block.root.intern(node.s))
   349      return expr.GeneratedLiteral(expr_str)
   350  
   351    def visit_Tuple(self, node):
   352      result = self.block.alloc_temp()
   353      if len(node.elts) <= util.MAX_DIRECT_TUPLE:
   354        with contextlib.nested(*(self.visit(e) for e in node.elts)) as elts:
   355          self.writer.write('{} = πg.NewTuple{}({}).ToObject()'.format(
   356              result.name, len(elts), ', '.join(e.expr for e in elts)))
   357      else:
   358        with self._visit_seq_elts(node.elts) as elems:
   359          self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format(
   360              result.expr, elems.expr))
   361      return result
   362  
   363    def visit_UnaryOp(self, node):
   364      result = self.block.alloc_temp()
   365      with self.visit(node.operand) as operand:
   366        op_type = type(node.op)
   367        if op_type in ExprVisitor._UNARY_OP_TEMPLATES:
   368          self.writer.write_checked_call2(
   369              result, ExprVisitor._UNARY_OP_TEMPLATES[op_type],
   370              operand=operand.expr)
   371        elif isinstance(node.op, ast.Not):
   372          with self.block.alloc_temp('bool') as is_true:
   373            self.writer.write_checked_call2(
   374                is_true, 'πg.IsTrue(πF, {})', operand.expr)
   375            self.writer.write('{} = πg.GetBool(!{}).ToObject()'.format(
   376                result.name, is_true.expr))
   377        else:
   378          msg = 'unary op not implemented: {}'.format(op_type.__name__)
   379          raise util.ParseError(node, msg)
   380      return result
   381  
   382    def visit_Yield(self, node):
   383      if node.value:
   384        value = self.visit(node.value)
   385      else:
   386        value = expr.GeneratedLiteral('πg.None')
   387      resume_label = self.block.genlabel(is_checkpoint=True)
   388      self.writer.write('πF.PushCheckpoint({})'.format(resume_label))
   389      self.writer.write('return {}, nil'.format(value.expr))
   390      self.writer.write_label(resume_label)
   391      result = self.block.alloc_temp()
   392      self.writer.write('{} = πSent'.format(result.name))
   393      return result
   394  
   395    _BIN_OP_TEMPLATES = {
   396        ast.BitAnd: 'πg.And(πF, {lhs}, {rhs})',
   397        ast.BitOr: 'πg.Or(πF, {lhs}, {rhs})',
   398        ast.BitXor: 'πg.Xor(πF, {lhs}, {rhs})',
   399        ast.Add: 'πg.Add(πF, {lhs}, {rhs})',
   400        ast.Div: 'πg.Div(πF, {lhs}, {rhs})',
   401        # TODO: Support "from __future__ import division".
   402        ast.FloorDiv: 'πg.FloorDiv(πF, {lhs}, {rhs})',
   403        ast.LShift: 'πg.LShift(πF, {lhs}, {rhs})',
   404        ast.Mod: 'πg.Mod(πF, {lhs}, {rhs})',
   405        ast.Mult: 'πg.Mul(πF, {lhs}, {rhs})',
   406        ast.Pow: 'πg.Pow(πF, {lhs}, {rhs})',
   407        ast.RShift: 'πg.RShift(πF, {lhs}, {rhs})',
   408        ast.Sub: 'πg.Sub(πF, {lhs}, {rhs})',
   409    }
   410  
   411    _CMP_OP_TEMPLATES = {
   412        ast.Eq: 'πg.Eq(πF, {lhs}, {rhs})',
   413        ast.Gt: 'πg.GT(πF, {lhs}, {rhs})',
   414        ast.GtE: 'πg.GE(πF, {lhs}, {rhs})',
   415        ast.Lt: 'πg.LT(πF, {lhs}, {rhs})',
   416        ast.LtE: 'πg.LE(πF, {lhs}, {rhs})',
   417        ast.NotEq: 'πg.NE(πF, {lhs}, {rhs})',
   418    }
   419  
   420    _UNARY_OP_TEMPLATES = {
   421        ast.Invert: 'πg.Invert(πF, {operand})',
   422        ast.UAdd: 'πg.Pos(πF, {operand})',
   423        ast.USub: 'πg.Neg(πF, {operand})',
   424    }
   425  
   426    def _visit_seq_elts(self, elts):
   427      result = self.block.alloc_temp('[]*πg.Object')
   428      self.writer.write('{} = make([]*πg.Object, {})'.format(
   429          result.expr, len(elts)))
   430      for i, e in enumerate(elts):
   431        with self.visit(e) as elt:
   432          self.writer.write('{}[{}] = {}'.format(result.expr, i, elt.expr))
   433      return result
   434  
   435    def _node_not_implemented(self, node):
   436      msg = 'node not yet implemented: ' + type(node).__name__
   437      raise util.ParseError(node, msg)
   438  
   439    visit_SetComp = _node_not_implemented