gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/internal/gen_tests/gen_tests.py (about)

     1  #!/usr/bin/env python3
     2  # -*- coding: utf-8 -*-
     3  '''Finds yaml tests, converts them to Go tests.'''
     4  from __future__ import print_function
     5  
     6  import sys
     7  import os
     8  import os.path
     9  import re
    10  import time
    11  import ast
    12  from re import sub, match, split, DOTALL
    13  import argparse
    14  import codecs
    15  import logging
    16  import process_polyglot
    17  import parse_polyglot
    18  from process_polyglot import Unhandled, Skip, FatalSkip, SkippedTest
    19  from mako.template import Template
    20  
    21  try:
    22      from cStringIO import StringIO
    23  except ImportError:
    24      from io import StringIO
    25  from collections import namedtuple
    26  
    27  parse_polyglot.printDebug = False
    28  logger = logging.getLogger("convert_tests")
    29  r = None
    30  
    31  
    32  # Individual tests can be ignored by specifying their line numbers from the YAML source files.
    33  TEST_EXCLUSIONS = {
    34      'regression/': None,
    35  
    36      'limits.yaml': [
    37          36, # Produces invalid Go code using `list` and `range`
    38          39, 41, 47, # Rely on above
    39          75, 87, 108, 120, # The fetch() calls hang
    40      ],
    41  
    42      'changefeeds/squash': [
    43          47, # double run
    44      ],
    45  
    46      'arity': [
    47          51, 100, 155, 272, 282, 291, # Malformed syntax produces malformed Go
    48      ],
    49  
    50      'sindex/truncation.rb.yaml': None, # The code generator fails when it encounters (0...n).map{}
    51  
    52      'changefeeds/geo.rb.yaml': None, # The generated code needs to enclose the map keys inside quotation marks
    53  
    54      'aggregation.yaml': [
    55          482, # A nil map key is expected, but Go strings cannot be nil, so actual result is ""
    56      ],
    57  
    58      # This file expects a password-less user `test_user` to exist, but even if it does,
    59      #   - Lines #18 and #27 try to create a database and fail if it already exists, but the YAML doesn't drop the database
    60      #     afterwards, so successive runs always fail.  Unfortunately, other test cases depend on these lines being run.
    61      #   - Lines #49, #54, #60, and #66 use a global run option `user` which Rethink doesn't recognize.
    62      'meta/grant.yaml': None,
    63  
    64      # The lambdas passed to SetWriteHook() must return rethinkdb.Term but the generated code uses interface{} instead.
    65      # These types of queries might be covered by reql_writehook_test.go though.
    66      'mutation/write_hook.yaml': None,
    67  
    68      'changefeeds/idxcopy.yaml': [
    69          28, # The fetch() hangs.
    70      ],
    71  
    72      'changefeeds/now.py_one.yaml': [
    73          5, # Hangs instead of returning an error.
    74      ],
    75  
    76      'times/constructors.yaml': [
    77          59, 61, # Tries to create year 10000 but Rethink now only allows up to 9999.
    78          64, 70, # The expected error message refers to 10000 but Rethink's error actually says 9999.
    79      ],
    80  
    81      'math_logic/logic.yaml': [
    82          143, 147, # Expected an error but got nil.  Don't know if it's a real problem.
    83      ],
    84  
    85      'control.yaml': [
    86          # Attempts to add 1 to a integer variable that was set to tbl.Count().  The Go '+' operator must be used instead of the ReQL '.Add()'.
    87          # Deciding which operator to use requires the variable type, which in turn requires knowing what type each ReQL function returns.
    88          # The robust solution is too much work, and checking for '.Count()' in the definition is a hack.
    89          236,
    90      ],
    91  
    92      'meta/dbs.yaml': [
    93          # For some reason the results include databases created by other tests (eg. 'examples', 'test_cur_all')
    94          6, 18, 31, 37,
    95      ],
    96  
    97      # For some reason the results include databases created by other tests (eg. 'examples', 'test_cur_all')
    98      'meta/composite.py.yaml': None,
    99  }
   100  
   101  GO_KEYWORDS = [
   102      'break', 'case', 'chan', 'const', 'continue', 'default', 'defer', 'else',
   103      'fallthrough', 'for', 'func', 'go', 'goto', 'if', 'import', 'interface',
   104      'map', 'package', 'range', 'return', 'select', 'struct', 'switch', 'type',
   105      'var'
   106  ]
   107  
   108  def main():
   109      logging.basicConfig(format="[%(name)s] %(message)s", level=logging.INFO)
   110      start = time.perf_counter()
   111      args = parse_args()
   112      if args.debug:
   113          logger.setLevel(logging.DEBUG)
   114          logging.getLogger('process_polyglot').setLevel(logging.DEBUG)
   115      elif args.info:
   116          logger.setLevel(logging.INFO)
   117          logging.getLogger('process_polyglot').setLevel(logging.INFO)
   118      else:
   119          logger.root.setLevel(logging.WARNING)
   120      if args.e:
   121          evaluate_snippet(args.e)
   122          exit(0)
   123      global r
   124      r = import_python_driver(args.python_driver_dir)
   125      renderer = Renderer(
   126          invoking_filenames=[
   127              __file__,
   128              process_polyglot.__file__,
   129          ])
   130      full_exclusions = list(file for (file, lines) in TEST_EXCLUSIONS.items() if not lines)
   131      for testfile in process_polyglot.all_yaml_tests(
   132              args.test_dir,
   133              full_exclusions):
   134          logger.info("Working on %s", testfile)
   135  
   136          excluded_lines = set()
   137          for exclusion, lines in TEST_EXCLUSIONS.items():
   138              if exclusion in testfile:
   139                  excluded_lines.update(lines)
   140  
   141          TestFile(
   142              test_dir=args.test_dir,
   143              filename=testfile,
   144              excluded_lines=excluded_lines,
   145              test_output_dir=args.test_output_dir,
   146              renderer=renderer,
   147          ).load().render()
   148      logger.info("Finished in %s seconds", time.perf_counter() - start)
   149  
   150  
   151  def parse_args():
   152      '''Parse command line arguments'''
   153      parser = argparse.ArgumentParser(description=__doc__)
   154      parser.add_argument(
   155          "--test-dir",
   156          help="Directory where yaml tests are",
   157          default=""
   158      )
   159      parser.add_argument(
   160          "--test-output-dir",
   161          help="Directory to render tests to",
   162          default=".",
   163      )
   164      parser.add_argument(
   165          "--python-driver-dir",
   166          help="Where the built python driver is located",
   167          default=""
   168      )
   169      parser.add_argument(
   170          "--test-file",
   171          help="Only convert the specified yaml file",
   172      )
   173      parser.add_argument(
   174          '--debug',
   175          help="Print debug output",
   176          dest='debug',
   177          action='store_true')
   178      parser.set_defaults(debug=False)
   179      parser.add_argument(
   180          '--info',
   181          help="Print info level output",
   182          dest='info',
   183          action='store_true')
   184      parser.set_defaults(info=False)
   185      parser.add_argument(
   186          '-e',
   187          help="Convert an inline python reql to go reql snippet",
   188      )
   189      return parser.parse_args()
   190  
   191  
   192  def import_python_driver(py_driver_dir):
   193      '''Imports the test driver header'''
   194      stashed_path = sys.path
   195      sys.path.insert(0, os.path.realpath(py_driver_dir))
   196      from rethinkdb import r
   197      sys.path = stashed_path
   198      return r
   199  
   200  GoQuery = namedtuple(
   201      'GoQuery',
   202      ('is_value',
   203       'line',
   204       'expected_type',
   205       'expected_line',
   206       'testfile',
   207       'line_num',
   208       'runopts')
   209  )
   210  GoDef = namedtuple(
   211      'GoDef',
   212      ('line',
   213       'varname',
   214       'vartype',
   215       'value',
   216       'run_if_query',
   217       'testfile',
   218       'line_num',
   219       'runopts')
   220  )
   221  Version = namedtuple("Version", "original go")
   222  
   223  GO_DECL = re.compile(r'var (?P<var>\w+) (?P<type>.+) = (?P<value>.*)')
   224  
   225  
   226  def evaluate_snippet(snippet):
   227      '''Just converts a single expression snippet into java'''
   228      try:
   229          parsed = ast.parse(snippet, mode='eval').body
   230      except Exception as e:
   231          return print("Error:", e)
   232      try:
   233          print(ReQLVisitor(smart_bracket=True).convert(parsed))
   234      except Exception as e:
   235          return print("Error:", e)
   236  
   237  
   238  class TestFile(object):
   239      '''Represents a single test file'''
   240  
   241      def __init__(self, test_dir, filename, excluded_lines, test_output_dir, renderer):
   242          self.filename = filename
   243          self.full_path = os.path.join(test_dir, filename)
   244          self.excluded_lines = excluded_lines
   245          self.module_name = filename.split('.')[0].replace('/', '_')
   246          self.test_output_dir = test_output_dir
   247          self.reql_vars = {'r'}
   248          self.renderer = renderer
   249  
   250      def load(self):
   251          '''Load the test file, yaml parse it, extract file-level metadata'''
   252          with open(self.full_path, encoding='utf-8') as f:
   253              parsed_yaml = parse_polyglot.parseYAML(f)
   254          self.description = parsed_yaml.get('desc', 'No description')
   255          self.table_var_names = self.get_varnames(parsed_yaml)
   256          self.reql_vars.update(self.table_var_names)
   257          self.raw_test_data = parsed_yaml['tests']
   258          self.test_generator = process_polyglot.tests_and_defs(
   259              self.filename,
   260              self.raw_test_data,
   261              context=process_polyglot.create_context(r, self.table_var_names),
   262              custom_field='go',
   263          )
   264          return self
   265  
   266      def get_varnames(self, yaml_file):
   267          '''Extract table variable names from yaml variable
   268          They can be specified just space separated, or comma separated'''
   269          raw_var_names = yaml_file.get('table_variable_name', '')
   270          if not raw_var_names:
   271              return set()
   272          return set(re.split(r'[, ]+', raw_var_names))
   273  
   274      def render(self):
   275          '''Renders the converted tests to a runnable test file'''
   276          defs_and_test = ast_to_go(self.test_generator, self.reql_vars, self.excluded_lines)
   277          self.renderer.source_files = [self.full_path]
   278          self.renderer.render(
   279              'template.go',
   280              output_dir=self.test_output_dir,
   281              output_name='reql_' + self.module_name + '_test.go',
   282              dependencies=[self.full_path],
   283              defs_and_test=defs_and_test,
   284              table_var_names=list(sorted(self.table_var_names)),
   285              module_name=camel(self.module_name),
   286              GoQuery=GoQuery,
   287              GoDef=GoDef,
   288              description=self.description,
   289          )
   290  
   291  
   292  def py_to_go_type(py_type):
   293      '''Converts python types to their Go equivalents'''
   294      if py_type is None:
   295          return None
   296      elif isinstance(py_type, str):
   297          # This can be called on something already converted
   298          return py_type
   299      elif py_type.__name__ == 'function':
   300          return 'func()'
   301      elif (py_type.__module__ == 'datetime' and
   302            py_type.__name__ == 'datetime'):
   303          return 'time.Time'
   304      elif py_type.__module__ == 'builtins':
   305          if py_type.__name__.endswith('Error'):
   306              return 'error'
   307  
   308          return {
   309              bool: 'bool',
   310              bytes: '[]byte',
   311              int: 'int',
   312              float: 'float64',
   313              str: 'string',
   314              dict: 'map[interface{}]interface{}',
   315              list: '[]interface{}',
   316              object: 'map[interface{}]interface{}',
   317              type(None): 'interface{}',
   318          }[py_type]
   319      elif py_type.__module__ == 'rethinkdb.ast':
   320          return "r.Term"
   321          # Anomalous non-rule based capitalization in the python driver
   322          # return {
   323          # }.get(py_type.__name__, py_type.__name__)
   324      elif py_type.__module__ == 'rethinkdb.errors':
   325          return py_type.__name__
   326      elif py_type.__module__ == '?test?':
   327          return {
   328              'int_cmp': 'int',
   329              'float_cmp': 'float64',
   330              'err_regex': 'Err',
   331              'partial': 'compare.Expected',
   332              'bag': 'compare.Expected',
   333              'uuid': 'compare.Regex',  # clashes with ast.Uuid
   334          }.get(py_type.__name__, camel(py_type.__name__))
   335      elif py_type.__module__ == 'rethinkdb.query':
   336          # ReQL constants don't have a type; they are just identifiers.
   337          return None
   338      else:
   339          raise Unhandled(
   340              "Don't know how to convert python type {}.{} to Go"
   341              .format(py_type.__module__, py_type.__name__))
   342  
   343  
   344  def def_to_go(item, reql_vars):
   345      if is_reql(item.term.type):
   346          reql_vars.add(item.varname)
   347      try:
   348          if is_reql(item.term.type):
   349              visitor = ReQLVisitor
   350          else:
   351              visitor = GoVisitor
   352          go_line = visitor(reql_vars,
   353                              type_=item.term.type,
   354                              is_def=True,
   355                              ).convert(item.term.ast)
   356      except Skip as skip:
   357          return SkippedTest(line=item.term.line, reason=str(skip))
   358      go_decl = GO_DECL.match(go_line).groupdict()
   359      return GoDef(
   360          line=Version(
   361              original=item.term.line,
   362              go=go_line,
   363          ),
   364          varname=go_decl['var'],
   365          vartype=go_decl['type'],
   366          value=go_decl['value'],
   367          run_if_query=item.run_if_query,
   368          testfile=item.testfile,
   369          line_num=item.line_num,
   370          runopts=convert_runopts(reql_vars, go_decl['type'], item.runopts)
   371      )
   372  
   373  
   374  def query_to_go(item, reql_vars):
   375      if item.runopts is not None:
   376          converted_runopts = convert_runopts(
   377              reql_vars, item.query.type, item.runopts)
   378      else:
   379          converted_runopts = item.runopts
   380      if converted_runopts is None:
   381          converted_runopts = {}
   382      converted_runopts['GeometryFormat'] = '\"raw\"'
   383      if 'GroupFormat' not in converted_runopts:
   384          converted_runopts['GroupFormat'] = '\"map\"'
   385      try:
   386          is_value = False
   387          if not item.query.type.__module__.startswith("rethinkdb.") or (not item.query.type.__module__.startswith("rethinkdb.") and item.query.type.__name__.endswith("Error")):
   388              is_value = True
   389          go_line = ReQLVisitor(
   390              reql_vars, type_=item.query.type).convert(item.query.ast)
   391          if is_reql(item.expected.type):
   392              visitor = ReQLVisitor
   393          else:
   394              visitor = GoVisitor
   395          go_expected_line = visitor(
   396              reql_vars, type_=item.expected.type)\
   397              .convert(item.expected.ast)
   398      except Skip as skip:
   399          return SkippedTest(line=item.query.line, reason=str(skip))
   400      return GoQuery(
   401          is_value=is_value,
   402          line=Version(
   403              original=item.query.line,
   404              go=go_line,
   405          ),
   406          expected_type=py_to_go_type(item.expected.type),
   407          expected_line=Version(
   408              original=item.expected.line,
   409              go=go_expected_line,
   410          ),
   411          testfile=item.testfile,
   412          line_num=item.line_num,
   413          runopts=converted_runopts,
   414      )
   415  
   416  
   417  def ast_to_go(sequence, reql_vars, excluded_lines):
   418      '''Converts the the parsed test data to go source lines using the
   419      visitor classes'''
   420      reql_vars = set(reql_vars)
   421      for item in sequence:
   422          if hasattr(item, 'line_num') and item.line_num in excluded_lines:
   423              logger.info("Skipped %s line %d due to exclusion", item.testfile, item.line_num)
   424          elif type(item) == process_polyglot.Def:
   425              yield def_to_go(item, reql_vars)
   426          elif type(item) == process_polyglot.CustomDef:
   427              yield GoDef(line=Version(item.line, item.line),
   428                            testfile=item.testfile,
   429                            line_num=item.line_num)
   430          elif type(item) == process_polyglot.Query:
   431              yield query_to_go(item, reql_vars)
   432          elif type(item) == SkippedTest:
   433              yield item
   434          else:
   435              assert False, "shouldn't happen, item was {}".format(item)
   436  
   437  
   438  def is_reql(t):
   439      '''Determines if a type is a reql term'''
   440      # Other options for module: builtins, ?test?, datetime
   441      if not hasattr(t, '__module__'):
   442          return True
   443  
   444      return t.__module__ == 'rethinkdb.ast'
   445  
   446  
   447  def escape_string(s, out):
   448      out.write('"')
   449      for codepoint in s:
   450          rpr = repr(codepoint)[1:-1]
   451          if rpr.startswith('\\x'):
   452              rpr = '\\u00' + rpr[2:]
   453          elif rpr == '"':
   454              rpr = r'\"'
   455          out.write(rpr)
   456      out.write('"')
   457  
   458  
   459  def attr_matches(path, node):
   460      '''Helper function. Several places need to know if they are an
   461      attribute of some root object'''
   462      root, name = path.split('.')
   463      ret = is_name(root, node.value) and node.attr == name
   464      return ret
   465  
   466  
   467  def is_name(name, node):
   468      '''Determine if the current attribute node is a Name with the
   469      given name'''
   470      return type(node) == ast.Name and node.id == name
   471  
   472  
   473  def convert_runopts(reql_vars, type_, runopts):
   474      if runopts is None:
   475          return None
   476      return {
   477          camel(key): GoVisitor(
   478              reql_vars, type_=type_).convert(val)
   479          for key, val in runopts.items()
   480      }
   481  
   482  
   483  class GoVisitor(ast.NodeVisitor):
   484      '''Converts python ast nodes into a Go string'''
   485  
   486      def __init__(self,
   487                   reql_vars=frozenset("r"),
   488                   out=None,
   489                   type_=None,
   490                   is_def=False,
   491                   smart_bracket=False,
   492      ):
   493          self.out = StringIO() if out is None else out
   494          self.reql_vars = reql_vars
   495          self.type = py_to_go_type(type_)
   496          self._type = type_
   497          self.is_def = is_def
   498          self.smart_bracket = smart_bracket
   499          super(GoVisitor, self).__init__()
   500          self.write = self.out.write
   501  
   502      def skip(self, message, *args, **kwargs):
   503          cls = Skip
   504          is_fatal = kwargs.pop('fatal', False)
   505          if self.is_def or is_fatal:
   506              cls = FatalSkip
   507          raise cls(message, *args, **kwargs)
   508  
   509      def convert(self, node):
   510          '''Convert a text line to another text line'''
   511          self.visit(node)
   512          return self.out.getvalue()
   513  
   514      def join(self, sep, items):
   515          first = True
   516          for item in items:
   517              if first:
   518                  first = False
   519              else:
   520                  self.write(sep)
   521              self.visit(item)
   522  
   523      def to_str(self, s):
   524          escape_string(s, self.out)
   525  
   526      def cast_null(self, arg):
   527          if (type(arg) == ast.Name and arg.id == 'null') or \
   528             (type(arg) == ast.NameConstant and arg.value == None):
   529              self.write("nil")
   530          else:
   531              self.visit(arg)
   532  
   533      def to_args(self, args, func='', optargs=[]):
   534          optargs_first = False
   535  
   536          self.write("(")
   537          if args:
   538              self.cast_null(args[0])
   539          for arg in args[1:]:
   540              self.write(', ')
   541              self.cast_null(arg)
   542          self.write(")")
   543  
   544          if optargs:
   545              self.to_args_optargs(func, optargs)
   546  
   547      def to_args_optargs(self, func='', optargs=[]):
   548          optarg_aliases = {
   549              'JsOpts': 'JSOpts',
   550              'HttpOpts': 'HTTPOpts',
   551              'Iso8601Opts': 'ISO8601Opts',
   552              'IndexCreateFuncOpts': 'IndexCreateOpts'
   553          }
   554          optarg_field_aliases = {
   555              'nonvoting_replica_tags': 'NonVotingReplicaTags',
   556          }
   557  
   558          if not func:
   559              raise Unhandled("Missing function name")
   560  
   561          optarg_type = camel(func) + 'Opts'
   562          optarg_type = optarg_aliases.get(optarg_type, optarg_type)
   563          optarg_type = 'r.' + optarg_type
   564  
   565          self.write('.OptArgs(')
   566          self.write(optarg_type)
   567          self.write('{')
   568          for optarg in optargs:
   569              # Hack to skip tests that check for unknown opt args,
   570              # this is not possible in Go due to static types
   571              if optarg.arg == 'return_vals':
   572                  self.skip("test not required since optargs are statically typed")
   573                  return
   574              if optarg.arg == 'foo':
   575                  self.skip("test not required since optargs are statically typed")
   576                  return
   577              if type(optarg.value) == ast.Name and optarg.value.id == 'null':
   578                  self.skip("test not required since go does not support null optargs")
   579                  return
   580  
   581  
   582              field_name = optarg_field_aliases.get(optarg.arg, camel(optarg.arg))
   583  
   584              self.write(field_name)
   585              self.write(": ")
   586              self.cast_null(optarg.value)
   587              self.write(', ')
   588          self.write('})')
   589  
   590      def generic_visit(self, node):
   591          logger.error("While translating: %s", ast.dump(node))
   592          logger.error("Got as far as: %s", ''.join(self.out))
   593          raise Unhandled("Don't know what this thing is: " + str(type(node)))
   594  
   595      def visit_Assign(self, node):
   596          if len(node.targets) != 1:
   597              Unhandled("We only support assigning to one variable")
   598          var = node.targets[0].id
   599          self.write("var " + var + " ")
   600  
   601          if is_reql(self._type):
   602              self.write('r.Term')
   603          else:
   604              self.write(self.type)
   605          self.write(" = ")
   606          if is_reql(self._type):
   607              ReQLVisitor(self.reql_vars,
   608                          out=self.out,
   609                          type_=self.type,
   610                          is_def=True,
   611                          ).visit(node.value)
   612          elif var == 'upper_limit': # Manually set value since value in test causes an error
   613              self.write('2<<52 - 1')
   614          elif var == 'lower_limit': # Manually set value since value in test causes an error
   615              self.write('1 - 2<<52')
   616          else:
   617              self.visit(node.value)
   618  
   619      def visit_Str(self, node):
   620          if node.s == 'ReqlServerCompileError':
   621              node.s = 'ReqlCompileError'
   622          # Hack to skip irrelevant tests
   623          if match(".*Expected .* argument", node.s):
   624              self.skip("argument checks not supported")
   625          if match(".*argument .* must", node.s):
   626              self.skip("argument checks not supported")
   627              return
   628          if node.s == 'Object keys must be strings.*':
   629              self.skip('the Go driver automatically converts object keys to strings')
   630              return
   631          if node.s.startswith('\'module\' object has no attribute '):
   632              self.skip('test not required since terms are statically typed')
   633              return
   634          self.to_str(node.s)
   635  
   636      def visit_Bytes(self, node, skip_prefix=False, skip_suffix=False):
   637          if not skip_prefix:
   638              self.write("[]byte{")
   639          for i, byte in enumerate(node.s):
   640              self.write(str(byte))
   641              if i < len(node.s)-1:
   642                  self.write(",")
   643          if not skip_suffix:
   644              self.write("}")
   645          else:
   646              self.write(", ")
   647  
   648      def visit_Name(self, node):
   649          name = node.id
   650          if name == 'frozenset':
   651              self.skip("can't convert frozensets to GroupedData yet")
   652          if name in GO_KEYWORDS:
   653              name += '_'
   654          self.write({
   655              'True': 'true',
   656              'False': 'false',
   657              'None': 'nil',
   658              'nil': 'nil',
   659              'null': 'nil',
   660              'float': 'float64',
   661              'datetime': 'Ast', # Use helper method instead
   662              'len': 'maybeLen',
   663              'AnythingIsFine': 'compare.AnythingIsFine',
   664              'bag': 'compare.UnorderedMatch',
   665              'partial': 'compare.PartialMatch',
   666              'uuid': 'compare.IsUUID',
   667              'regex': 'compare.MatchesRegexp',
   668              }.get(name, name))
   669  
   670      def visit_arg(self, node):
   671          self.write(node.arg)
   672          self.write(" ")
   673          if is_reql(self._type):
   674              self.write("r.Term")
   675          else:
   676              self.write("interface{}")
   677  
   678      def visit_NameConstant(self, node):
   679          if node.value is None:
   680              self.write("nil")
   681          elif node.value is True:
   682              self.write("true")
   683          elif node.value is False:
   684              self.write("false")
   685          else:
   686              raise Unhandled(
   687                  "Don't know NameConstant with value %s" % node.value)
   688  
   689      def visit_Attribute(self, node, emit_parens=True):
   690          skip_parent = False
   691          if attr_matches("r.ast", node):
   692              # The java driver doesn't have that namespace, so we skip
   693              # the `r.` prefix and create an ast class member in the
   694              # test file. So stuff like `r.ast.rqlTzinfo(...)` converts
   695              # to `ast.rqlTzinfo(...)`
   696              skip_parent = True
   697  
   698  
   699          if not skip_parent:
   700              self.visit(node.value)
   701              self.write(".")
   702  
   703          attr = ReQLVisitor.convertTermName(node.attr)
   704          self.write(attr)
   705  
   706      def visit_Num(self, node):
   707          self.write(repr(node.n))
   708          if not isinstance(node.n, float):
   709              if node.n > 9223372036854775807 or node.n < -9223372036854775808:
   710                  self.write(".0")
   711  
   712      def visit_Index(self, node):
   713          self.visit(node.value)
   714  
   715      def visit_Call(self, node):
   716          func = ''
   717          if type(node.func) == ast.Attribute and node.func.attr == 'encode':
   718              self.skip("Go tests do not currently support character encoding")
   719              return
   720          if type(node.func) == ast.Attribute and node.func.attr == 'error':
   721              # This weird special case is because sometimes the tests
   722              # use r.error and sometimes they use r.error(). The java
   723              # driver only supports r.error(). Since we're coming in
   724              # from a call here, we have to prevent visit_Attribute
   725              # from emitting the parents on an r.error for us.
   726              self.visit_Attribute(node.func, emit_parens=False)
   727          elif type(node.func) == ast.Name and node.func.id == 'err':
   728              # Throw away third argument as it is not used by the Go tests
   729              # and Go does not support function overloading
   730              node.args = node.args[:2]
   731              self.visit(node.func)
   732          elif type(node.func) == ast.Name and node.func.id == 'err_regex':
   733              # Throw away third argument as it is not used by the Go tests
   734              # and Go does not support function overloading
   735              node.args = node.args[:2]
   736              self.visit(node.func)
   737          elif type(node.func) == ast.Name and node.func.id == 'fetch':
   738              if len(node.args) == 1:
   739                  node.args.append(ast.Constant(0))
   740              elif len(node.args) > 2:
   741                  node.args = node.args[:2]
   742              self.visit(node.func)
   743          else:
   744              self.visit(node.func)
   745  
   746          if type(node.func) == ast.Attribute:
   747              func = node.func.attr
   748          elif type(node.func) == ast.Name:
   749              func = node.func.id
   750  
   751          self.to_args(node.args, func, node.keywords)
   752  
   753      def visit_Dict(self, node):
   754          self.write("map[interface{}]interface{}{")
   755          for k, v in zip(node.keys, node.values):
   756              self.visit(k)
   757              self.write(": ")
   758              self.visit(v)
   759              self.write(", ")
   760          self.write("}")
   761  
   762      def visit_List(self, node):
   763          self.write("[]interface{}{")
   764          self.join(", ", node.elts)
   765          self.write("}")
   766  
   767      def visit_Tuple(self, node):
   768          self.visit_List(node)
   769  
   770      def visit_Lambda(self, node):
   771          self.write("func")
   772          self.to_args(node.args.args)
   773          self.write(" interface{} { return ")
   774          self.visit(node.body)
   775          self.write("}")
   776  
   777      def visit_Subscript(self, node):
   778          if node.slice is not None and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == int:
   779              self.visit(node.value)
   780              self.write("[")
   781              self.write(str(node.slice.value.n))
   782              self.write("]")
   783          else:
   784              logger.error("While doing: %s", ast.dump(node))
   785              raise Unhandled("Only integers subscript can be converted."
   786                              " Got %s" % node.slice.value.s)
   787  
   788      def visit_ListComp(self, node):
   789          gen = node.generators[0]
   790  
   791          start = 0
   792          end = 0
   793  
   794          if type(gen.iter) == ast.Call and gen.iter.func.id.endswith('range'):
   795              # This is really a special-case hacking of [... for i in
   796              # range(i)] comprehensions that are used in the polyglot
   797              # tests sometimes. It won't handle translating arbitrary
   798              # comprehensions to Java streams.
   799  
   800              if len(gen.iter.args) == 1:
   801                  end = gen.iter.args[0].n
   802              elif len(gen.iter.args) == 2:
   803                  start = gen.iter.args[0].n
   804                  end = gen.iter.args[1].n
   805          else:
   806              # Somebody came up with a creative new use for
   807              # comprehensions in the test suite...
   808              raise Unhandled("ListComp hack couldn't handle: ", ast.dump(node))
   809  
   810          self.write("(func() []interface{} {\n")
   811          self.write("    res := []interface{}{}\n")
   812          self.write("    for iterator_ := %s; iterator_ < %s; iterator_++ {\n" % (start, end))
   813          self.write("        "); self.visit(gen.target); self.write(" := iterator_\n")
   814          self.write("        res = append(res, ")
   815          self.visit(node.elt)
   816          self.write(")\n")
   817          self.write("    }\n")
   818          self.write("    return res\n")
   819          self.write("}())")
   820  
   821      def visit_UnaryOp(self, node):
   822          opMap = {
   823              ast.USub: "-",
   824              ast.Not: "!",
   825              ast.UAdd: "+",
   826              ast.Invert: "~",
   827          }
   828          self.write(opMap[type(node.op)])
   829          self.visit(node.operand)
   830  
   831      def visit_BinOp(self, node):
   832          opMap = {
   833              ast.Add: " + ",
   834              ast.Sub: " - ",
   835              ast.Mult: " * ",
   836              ast.Div: " / ",
   837              ast.Mod: " % ",
   838          }
   839          if self.is_string_mul(node):
   840              return
   841          if self.is_array_concat(node):
   842              return
   843          if self.is_byte_array_add(node):
   844              return
   845  
   846          t = type(node.op)
   847          if t in opMap.keys():
   848              self.visit(node.left)
   849              self.write(opMap[t])
   850              self.visit(node.right)
   851          elif t == ast.Pow:
   852              if type(node.left) == ast.Constant and type(node.left.value) == int and node.left.n == 2:
   853                  self.visit(node.left)
   854                  self.write(" << ")
   855                  self.visit(node.right)
   856              else:
   857                  raise Unhandled("Can't do exponent with non 2 base")
   858  
   859      def is_byte_array_add(self, node):
   860          '''Some places we do stuff like b'foo' + b'bar' and byte
   861          arrays don't like that much'''
   862          if (type(node.left) == ast.Bytes and
   863             type(node.right) == ast.Bytes and
   864             type(node.op) == ast.Add):
   865              self.visit_Bytes(node.left, skip_suffix=True)
   866              self.visit_Bytes(node.right, skip_prefix=True)
   867              return True
   868          else:
   869              return False
   870  
   871      def is_string_mul(self, node):
   872          if ((type(node.left) == ast.Constant and type(node.left.value) == str and type(node.right) == ast.Constant and type(node.right.value) == int) and type(node.op) == ast.Mult):
   873              self.write("\"")
   874              self.write(node.left.s * node.right.n)
   875              self.write("\"")
   876              return True
   877          elif ((type(node.left) == ast.Constant and type(node.left.value) == int and type(node.right) == ast.Constant and type(node.right.value) == str) and type(node.op) == ast.Mult):
   878              self.write("\"")
   879              self.write(node.left.n * node.right.s)
   880              self.write("\"")
   881              return True
   882          else:
   883              return False
   884  
   885      def is_array_concat(self, node):
   886          if ((type(node.left) == ast.List or type(node.right) == ast.List) and type(node.op) == ast.Add):
   887              self.skip("Array concatenation using + operator not currently supported")
   888              return True
   889          else:
   890              return False
   891  
   892  
   893  class ReQLVisitor(GoVisitor):
   894      '''Mostly the same as the GoVisitor, but converts some
   895      reql-specific stuff. This should only be invoked on an expression
   896      if it's already known to return true from is_reql'''
   897  
   898      TOPLEVEL_CONSTANTS = {
   899          'error'
   900      }
   901  
   902      def visit_BinOp(self, node):
   903          if self.is_string_mul(node):
   904              return
   905          if self.is_array_concat(node):
   906              return
   907          if self.is_byte_array_add(node):
   908              return
   909          opMap = {
   910              ast.Add: "Add",
   911              ast.Sub: "Sub",
   912              ast.Mult: "Mul",
   913              ast.Div: "Div",
   914              ast.Mod: "Mod",
   915              ast.BitAnd: "And",
   916              ast.BitOr: "Or",
   917          }
   918          func = opMap[type(node.op)]
   919          if self.is_not_reql(node.left):
   920              self.prefix(func, node.left, node.right)
   921          else:
   922              self.infix(func, node.left, node.right)
   923  
   924      def visit_Compare(self, node):
   925          opMap = {
   926              ast.Lt: "Lt",
   927              ast.Gt: "Gt",
   928              ast.GtE: "Ge",
   929              ast.LtE: "Le",
   930              ast.Eq: "Eq",
   931              ast.NotEq: "Ne",
   932          }
   933          if len(node.ops) != 1:
   934              # Python syntax allows chained comparisons (a < b < c) but
   935              # we don't deal with that here
   936              raise Unhandled("Compare hack bailed on: ", ast.dump(node))
   937          left = node.left
   938          right = node.comparators[0]
   939          func_name = opMap[type(node.ops[0])]
   940          if self.is_not_reql(node.left):
   941              self.prefix(func_name, left, right)
   942          else:
   943              self.infix(func_name, left, right)
   944  
   945      def prefix(self, func_name, left, right):
   946          self.write("r.")
   947          self.write(func_name)
   948          self.write("(")
   949          self.visit(left)
   950          self.write(", ")
   951          self.visit(right)
   952          self.write(")")
   953  
   954      def infix(self, func_name, left, right):
   955          self.visit(left)
   956          self.write(".")
   957          self.write(func_name)
   958          self.write("(")
   959          self.visit(right)
   960          self.write(")")
   961  
   962      def is_not_reql(self, node):
   963          return type(node) in (ast.Constant, ast.Name, ast.NameConstant, ast.Dict, ast.List)
   964  
   965      def visit_Subscript(self, node):
   966          self.visit(node.value)
   967          if type(node.slice) == ast.Index:
   968              # Syntax like a[2] or a["b"]
   969              if self.smart_bracket and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == str:
   970                  self.write(".Field(")
   971              elif self.smart_bracket and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == int:
   972                  self.write(".Nth(")
   973              else:
   974                  self.write(".AtIndex(")
   975              self.visit(node.slice.value)
   976              self.write(")")
   977          elif type(node.slice) == ast.Slice:
   978              # Syntax like a[1:2] or a[:2]
   979              self.write(".Slice(")
   980              lower, upper, rclosed = self.get_slice_bounds(node.slice)
   981              self.write(str(lower))
   982              self.write(", ")
   983              self.write(str(upper))
   984              if rclosed:
   985                  self.write(', r.SliceOpts{RightBound: "closed"}')
   986              self.write(")")
   987          else:
   988              raise Unhandled("No translation for ExtSlice")
   989  
   990      def get_slice_bounds(self, slc):
   991          '''Used to extract bounds when using bracket slice
   992          syntax. This is more complicated since Python3 parses -1 as
   993          UnaryOp(op=USub, operand=Num(1)) instead of Num(-1) like
   994          Python2 does'''
   995          if not slc:
   996              return 0, -1, True
   997  
   998          def get_bound(bound, default):
   999              if bound is None:
  1000                  return default
  1001              elif type(bound) == ast.UnaryOp and type(bound.op) == ast.USub:
  1002                  return -bound.operand.n
  1003              elif type(bound) == ast.Constant and type(bound.value) == int:
  1004                  return bound.n
  1005              else:
  1006                  raise Unhandled(
  1007                      "Not handling bound: %s" % ast.dump(bound))
  1008  
  1009          right_closed = slc.upper is None
  1010  
  1011          return get_bound(slc.lower, 0), get_bound(slc.upper, -1), right_closed
  1012  
  1013      def convertTermName(term):
  1014          python_clashes = {
  1015              # These are underscored in the python driver to avoid
  1016              # keywords, but they aren't java keywords so we convert
  1017              # them back.
  1018              'or_': 'Or',
  1019              'and_': 'And',
  1020              'not_': 'Not',
  1021          }
  1022          method_aliases = {
  1023              'get_field': 'Field',
  1024              'db': 'DB',
  1025              'db_create': 'DBCreate',
  1026              'db_drop': 'DBDrop',
  1027              'db_list': 'DBList',
  1028              'uuid': 'UUID',
  1029              'geojson': 'GeoJSON',
  1030              'js': 'JS',
  1031              'json': 'JSON',
  1032              'to_json': 'ToJSON',
  1033              'to_json_string': 'ToJSON',
  1034              'minval': 'MinVal',
  1035              'maxval': 'MaxVal',
  1036              'http': 'HTTP',
  1037              'iso8601': 'ISO8601',
  1038              'to_iso8601': 'ToISO8601',
  1039          }
  1040  
  1041          return python_clashes.get(term, method_aliases.get(term, camel(term)))
  1042  
  1043      def visit_Attribute(self, node, emit_parens=True):
  1044          is_toplevel_constant = False
  1045          # if attr_matches("r.row", node):
  1046          # elif is_name("r", node.value) and node.attr in self.TOPLEVEL_CONSTANTS:
  1047          if is_name("r", node.value) and node.attr in self.TOPLEVEL_CONSTANTS:
  1048              # Python has r.minval, r.saturday etc. We need to emit
  1049              # r.minval() and r.saturday()
  1050              is_toplevel_constant = True
  1051  
  1052          initial = ReQLVisitor.convertTermName(node.attr)
  1053  
  1054          self.visit(node.value)
  1055          self.write(".")
  1056          self.write(initial)
  1057          if initial in GO_KEYWORDS:
  1058              self.write('_')
  1059          if emit_parens and is_toplevel_constant:
  1060              self.write('()')
  1061  
  1062      def visit_UnaryOp(self, node):
  1063          if type(node.op) == ast.Invert:
  1064              self.visit(node.operand)
  1065              self.write(".Not()")
  1066          else:
  1067              super(ReQLVisitor, self).visit_UnaryOp(node)
  1068  
  1069      def visit_Call(self, node):
  1070          if (attr_equals(node.func, "attr", "index_create") and len(node.args) == 2):
  1071              node.func.attr = 'index_create_func'
  1072  
  1073          # We call the superclass first, so if it's going to fail
  1074          # because of r.row or other things it fails first, rather than
  1075          # hitting the checks in this method. Since everything is
  1076          # written to a stringIO object not directly to a file, if we
  1077          # bail out afterwards it's still ok
  1078          super_result = super(ReQLVisitor, self).visit_Call(node)
  1079  
  1080          # r.expr(v, 1) should be skipped
  1081          if (attr_equals(node.func, "attr", "expr") and len(node.args) > 1):
  1082              self.skip("the go driver only accepts one parameter to expr")
  1083          # r.table_create("a", "b") should be skipped
  1084          if (attr_equals(node.func, "attr", "table_create") and len(node.args) > 1):
  1085              self.skip("the go driver only accepts one parameter to table_create")
  1086          return super_result
  1087  
  1088  class Renderer(object):
  1089      '''Manages rendering templates'''
  1090  
  1091      def __init__(self, invoking_filenames, source_files=None):
  1092          self.template_file = './template.go.tpl'
  1093          self.invoking_filenames = invoking_filenames
  1094          self.source_files = source_files or []
  1095          self.tpl = Template(filename=self.template_file)
  1096          self.template_context = {
  1097              'EmptyTemplate': process_polyglot.EmptyTemplate,
  1098          }
  1099  
  1100      def render(self,
  1101                 template_name,
  1102                 output_dir,
  1103                 output_name=None,
  1104                 **kwargs):
  1105          if output_name is None:
  1106              output_name = template_name
  1107  
  1108          output_path = output_dir + '/' + output_name
  1109  
  1110          results = self.template_context.copy()
  1111          results.update(kwargs)
  1112          try:
  1113              rendered = self.tpl.render(**results)
  1114          except process_polyglot.EmptyTemplate:
  1115              logger.debug(" Empty template: %s", output_path)
  1116              return
  1117          with codecs.open(output_path, "w", "utf-8") as outfile:
  1118              logger.info("Rendering %s", output_path)
  1119              outfile.write(self.autogenerated_header(
  1120                  self.template_file,
  1121                  output_path,
  1122                  self.invoking_filenames,
  1123              ))
  1124              outfile.write(rendered)
  1125  
  1126      def autogenerated_header(self, template_path, output_path, filename):
  1127          rel_tpl = os.path.relpath(template_path, start=output_path)
  1128          filenames = ' and '.join(os.path.basename(f)
  1129                                   for f in self.invoking_filenames)
  1130          return ('// Code generated by {}.\n'
  1131                  '// Do not edit this file directly.\n'
  1132                  '// The template for this file is located at:\n'
  1133                  '// {}\n').format(filenames, rel_tpl)
  1134  
  1135  def camel(varname):
  1136      'CamelCase'
  1137      if re.match(r'[A-Z][A-Z0-9_]*$|[a-z][a-z0-9_]*$', varname):
  1138          # if snake-case (upper or lower) camelize it
  1139          suffix = "_" if varname.endswith('_') else ""
  1140          return ''.join(x.title() for x in varname.split('_')) + suffix
  1141      else:
  1142          # if already mixed case, just capitalize the first letter
  1143          return varname[0].upper() + varname[1:]
  1144  
  1145  
  1146  def dromedary(varname):
  1147      'dromedaryCase'
  1148      if re.match(r'[A-Z][A-Z0-9_]*$|[a-z][a-z0-9_]*$', varname):
  1149          chunks = varname.split('_')
  1150          suffix = "_" if varname.endswith('_') else ""
  1151          return (chunks[0].lower() +
  1152                  ''.join(x.title() for x in chunks[1:]) +
  1153                  suffix)
  1154      else:
  1155          return varname[0].lower() + varname[1:]
  1156  
  1157  def attr_equals(node, attr, value):
  1158      '''Helper for digging into ast nodes'''
  1159      return hasattr(node, attr) and getattr(node, attr) == value
  1160  
# Invoke main() only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()