gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/internal/gen_tests/gen_tests.py (about) 1 #!/usr/bin/env python3 2 # -*- coding: utf-8 -*- 3 '''Finds yaml tests, converts them to Go tests.''' 4 from __future__ import print_function 5 6 import sys 7 import os 8 import os.path 9 import re 10 import time 11 import ast 12 from re import sub, match, split, DOTALL 13 import argparse 14 import codecs 15 import logging 16 import process_polyglot 17 import parse_polyglot 18 from process_polyglot import Unhandled, Skip, FatalSkip, SkippedTest 19 from mako.template import Template 20 21 try: 22 from cStringIO import StringIO 23 except ImportError: 24 from io import StringIO 25 from collections import namedtuple 26 27 parse_polyglot.printDebug = False 28 logger = logging.getLogger("convert_tests") 29 r = None 30 31 32 # Individual tests can be ignored by specifying their line numbers from the YAML source files. 33 TEST_EXCLUSIONS = { 34 'regression/': None, 35 36 'limits.yaml': [ 37 36, # Produces invalid Go code using `list` and `range` 38 39, 41, 47, # Rely on above 39 75, 87, 108, 120, # The fetch() calls hang 40 ], 41 42 'changefeeds/squash': [ 43 47, # double run 44 ], 45 46 'arity': [ 47 51, 100, 155, 272, 282, 291, # Malformed syntax produces malformed Go 48 ], 49 50 'sindex/truncation.rb.yaml': None, # The code generator fails when it encounters (0...n).map{} 51 52 'changefeeds/geo.rb.yaml': None, # The generated code needs to enclose the map keys inside quotation marks 53 54 'aggregation.yaml': [ 55 482, # A nil map key is expected, but Go strings cannot be nil, so actual result is "" 56 ], 57 58 # This file expects a password-less user `test_user` to exist, but even if it does, 59 # - Lines #18 and #27 try to create a database and fail if it already exists, but the YAML doesn't drop the database 60 # afterwards, so successive runs always fail. Unfortunately, other test cases depend on these lines being run. 61 # - Lines #49, #54, #60, and #66 use a global run option `user` which Rethink doesn't recognize. 62 'meta/grant.yaml': None, 63 64 # The lambdas passed to SetWriteHook() must return rethinkdb.Term but the generated code uses interface{} instead. 65 # These types of queries might be covered by reql_writehook_test.go though. 66 'mutation/write_hook.yaml': None, 67 68 'changefeeds/idxcopy.yaml': [ 69 28, # The fetch() hangs. 70 ], 71 72 'changefeeds/now.py_one.yaml': [ 73 5, # Hangs instead of returning an error. 74 ], 75 76 'times/constructors.yaml': [ 77 59, 61, # Tries to create year 10000 but Rethink now only allows up to 9999. 78 64, 70, # The expected error message refers to 10000 but Rethink's error actually says 9999. 79 ], 80 81 'math_logic/logic.yaml': [ 82 143, 147, # Expected an error but got nil. Don't know if it's a real problem. 83 ], 84 85 'control.yaml': [ 86 # Attempts to add 1 to a integer variable that was set to tbl.Count(). The Go '+' operator must be used instead of the ReQL '.Add()'. 87 # Deciding which operator to use requires the variable type, which in turn requires knowing what type each ReQL function returns. 88 # The robust solution is too much work, and checking for '.Count()' in the definition is a hack. 89 236, 90 ], 91 92 'meta/dbs.yaml': [ 93 # For some reason the results include databases created by other tests (eg. 'examples', 'test_cur_all') 94 6, 18, 31, 37, 95 ], 96 97 # For some reason the results include databases created by other tests (eg. 'examples', 'test_cur_all') 98 'meta/composite.py.yaml': None, 99 } 100 101 GO_KEYWORDS = [ 102 'break', 'case', 'chan', 'const', 'continue', 'default', 'defer', 'else', 103 'fallthrough', 'for', 'func', 'go', 'goto', 'if', 'import', 'interface', 104 'map', 'package', 'range', 'return', 'select', 'struct', 'switch', 'type', 105 'var' 106 ] 107 108 def main(): 109 logging.basicConfig(format="[%(name)s] %(message)s", level=logging.INFO) 110 start = time.perf_counter() 111 args = parse_args() 112 if args.debug: 113 logger.setLevel(logging.DEBUG) 114 logging.getLogger('process_polyglot').setLevel(logging.DEBUG) 115 elif args.info: 116 logger.setLevel(logging.INFO) 117 logging.getLogger('process_polyglot').setLevel(logging.INFO) 118 else: 119 logger.root.setLevel(logging.WARNING) 120 if args.e: 121 evaluate_snippet(args.e) 122 exit(0) 123 global r 124 r = import_python_driver(args.python_driver_dir) 125 renderer = Renderer( 126 invoking_filenames=[ 127 __file__, 128 process_polyglot.__file__, 129 ]) 130 full_exclusions = list(file for (file, lines) in TEST_EXCLUSIONS.items() if not lines) 131 for testfile in process_polyglot.all_yaml_tests( 132 args.test_dir, 133 full_exclusions): 134 logger.info("Working on %s", testfile) 135 136 excluded_lines = set() 137 for exclusion, lines in TEST_EXCLUSIONS.items(): 138 if exclusion in testfile: 139 excluded_lines.update(lines) 140 141 TestFile( 142 test_dir=args.test_dir, 143 filename=testfile, 144 excluded_lines=excluded_lines, 145 test_output_dir=args.test_output_dir, 146 renderer=renderer, 147 ).load().render() 148 logger.info("Finished in %s seconds", time.perf_counter() - start) 149 150 151 def parse_args(): 152 '''Parse command line arguments''' 153 parser = argparse.ArgumentParser(description=__doc__) 154 parser.add_argument( 155 "--test-dir", 156 help="Directory where yaml tests are", 157 default="" 158 ) 159 parser.add_argument( 160 "--test-output-dir", 161 help="Directory to render tests to", 162 default=".", 163 ) 164 parser.add_argument( 165 "--python-driver-dir", 166 help="Where the built python driver is located", 167 default="" 168 ) 169 parser.add_argument( 170 "--test-file", 171 help="Only convert the specified yaml file", 172 ) 173 parser.add_argument( 174 '--debug', 175 help="Print debug output", 176 dest='debug', 177 action='store_true') 178 parser.set_defaults(debug=False) 179 parser.add_argument( 180 '--info', 181 help="Print info level output", 182 dest='info', 183 action='store_true') 184 parser.set_defaults(info=False) 185 parser.add_argument( 186 '-e', 187 help="Convert an inline python reql to go reql snippet", 188 ) 189 return parser.parse_args() 190 191 192 def import_python_driver(py_driver_dir): 193 '''Imports the test driver header''' 194 stashed_path = sys.path 195 sys.path.insert(0, os.path.realpath(py_driver_dir)) 196 from rethinkdb import r 197 sys.path = stashed_path 198 return r 199 200 GoQuery = namedtuple( 201 'GoQuery', 202 ('is_value', 203 'line', 204 'expected_type', 205 'expected_line', 206 'testfile', 207 'line_num', 208 'runopts') 209 ) 210 GoDef = namedtuple( 211 'GoDef', 212 ('line', 213 'varname', 214 'vartype', 215 'value', 216 'run_if_query', 217 'testfile', 218 'line_num', 219 'runopts') 220 ) 221 Version = namedtuple("Version", "original go") 222 223 GO_DECL = re.compile(r'var (?P<var>\w+) (?P<type>.+) = (?P<value>.*)') 224 225 226 def evaluate_snippet(snippet): 227 '''Just converts a single expression snippet into java''' 228 try: 229 parsed = ast.parse(snippet, mode='eval').body 230 except Exception as e: 231 return print("Error:", e) 232 try: 233 print(ReQLVisitor(smart_bracket=True).convert(parsed)) 234 except Exception as e: 235 return print("Error:", e) 236 237 238 class TestFile(object): 239 '''Represents a single test file''' 240 241 def __init__(self, test_dir, filename, excluded_lines, test_output_dir, renderer): 242 self.filename = filename 243 self.full_path = os.path.join(test_dir, filename) 244 self.excluded_lines = excluded_lines 245 self.module_name = filename.split('.')[0].replace('/', '_') 246 self.test_output_dir = test_output_dir 247 self.reql_vars = {'r'} 248 self.renderer = renderer 249 250 def load(self): 251 '''Load the test file, yaml parse it, extract file-level metadata''' 252 with open(self.full_path, encoding='utf-8') as f: 253 parsed_yaml = parse_polyglot.parseYAML(f) 254 self.description = parsed_yaml.get('desc', 'No description') 255 self.table_var_names = self.get_varnames(parsed_yaml) 256 self.reql_vars.update(self.table_var_names) 257 self.raw_test_data = parsed_yaml['tests'] 258 self.test_generator = process_polyglot.tests_and_defs( 259 self.filename, 260 self.raw_test_data, 261 context=process_polyglot.create_context(r, self.table_var_names), 262 custom_field='go', 263 ) 264 return self 265 266 def get_varnames(self, yaml_file): 267 '''Extract table variable names from yaml variable 268 They can be specified just space separated, or comma separated''' 269 raw_var_names = yaml_file.get('table_variable_name', '') 270 if not raw_var_names: 271 return set() 272 return set(re.split(r'[, ]+', raw_var_names)) 273 274 def render(self): 275 '''Renders the converted tests to a runnable test file''' 276 defs_and_test = ast_to_go(self.test_generator, self.reql_vars, self.excluded_lines) 277 self.renderer.source_files = [self.full_path] 278 self.renderer.render( 279 'template.go', 280 output_dir=self.test_output_dir, 281 output_name='reql_' + self.module_name + '_test.go', 282 dependencies=[self.full_path], 283 defs_and_test=defs_and_test, 284 table_var_names=list(sorted(self.table_var_names)), 285 module_name=camel(self.module_name), 286 GoQuery=GoQuery, 287 GoDef=GoDef, 288 description=self.description, 289 ) 290 291 292 def py_to_go_type(py_type): 293 '''Converts python types to their Go equivalents''' 294 if py_type is None: 295 return None 296 elif isinstance(py_type, str): 297 # This can be called on something already converted 298 return py_type 299 elif py_type.__name__ == 'function': 300 return 'func()' 301 elif (py_type.__module__ == 'datetime' and 302 py_type.__name__ == 'datetime'): 303 return 'time.Time' 304 elif py_type.__module__ == 'builtins': 305 if py_type.__name__.endswith('Error'): 306 return 'error' 307 308 return { 309 bool: 'bool', 310 bytes: '[]byte', 311 int: 'int', 312 float: 'float64', 313 str: 'string', 314 dict: 'map[interface{}]interface{}', 315 list: '[]interface{}', 316 object: 'map[interface{}]interface{}', 317 type(None): 'interface{}', 318 }[py_type] 319 elif py_type.__module__ == 'rethinkdb.ast': 320 return "r.Term" 321 # Anomalous non-rule based capitalization in the python driver 322 # return { 323 # }.get(py_type.__name__, py_type.__name__) 324 elif py_type.__module__ == 'rethinkdb.errors': 325 return py_type.__name__ 326 elif py_type.__module__ == '?test?': 327 return { 328 'int_cmp': 'int', 329 'float_cmp': 'float64', 330 'err_regex': 'Err', 331 'partial': 'compare.Expected', 332 'bag': 'compare.Expected', 333 'uuid': 'compare.Regex', # clashes with ast.Uuid 334 }.get(py_type.__name__, camel(py_type.__name__)) 335 elif py_type.__module__ == 'rethinkdb.query': 336 # ReQL constants don't have a type; they are just identifiers. 337 return None 338 else: 339 raise Unhandled( 340 "Don't know how to convert python type {}.{} to Go" 341 .format(py_type.__module__, py_type.__name__)) 342 343 344 def def_to_go(item, reql_vars): 345 if is_reql(item.term.type): 346 reql_vars.add(item.varname) 347 try: 348 if is_reql(item.term.type): 349 visitor = ReQLVisitor 350 else: 351 visitor = GoVisitor 352 go_line = visitor(reql_vars, 353 type_=item.term.type, 354 is_def=True, 355 ).convert(item.term.ast) 356 except Skip as skip: 357 return SkippedTest(line=item.term.line, reason=str(skip)) 358 go_decl = GO_DECL.match(go_line).groupdict() 359 return GoDef( 360 line=Version( 361 original=item.term.line, 362 go=go_line, 363 ), 364 varname=go_decl['var'], 365 vartype=go_decl['type'], 366 value=go_decl['value'], 367 run_if_query=item.run_if_query, 368 testfile=item.testfile, 369 line_num=item.line_num, 370 runopts=convert_runopts(reql_vars, go_decl['type'], item.runopts) 371 ) 372 373 374 def query_to_go(item, reql_vars): 375 if item.runopts is not None: 376 converted_runopts = convert_runopts( 377 reql_vars, item.query.type, item.runopts) 378 else: 379 converted_runopts = item.runopts 380 if converted_runopts is None: 381 converted_runopts = {} 382 converted_runopts['GeometryFormat'] = '\"raw\"' 383 if 'GroupFormat' not in converted_runopts: 384 converted_runopts['GroupFormat'] = '\"map\"' 385 try: 386 is_value = False 387 if not item.query.type.__module__.startswith("rethinkdb.") or (not item.query.type.__module__.startswith("rethinkdb.") and item.query.type.__name__.endswith("Error")): 388 is_value = True 389 go_line = ReQLVisitor( 390 reql_vars, type_=item.query.type).convert(item.query.ast) 391 if is_reql(item.expected.type): 392 visitor = ReQLVisitor 393 else: 394 visitor = GoVisitor 395 go_expected_line = visitor( 396 reql_vars, type_=item.expected.type)\ 397 .convert(item.expected.ast) 398 except Skip as skip: 399 return SkippedTest(line=item.query.line, reason=str(skip)) 400 return GoQuery( 401 is_value=is_value, 402 line=Version( 403 original=item.query.line, 404 go=go_line, 405 ), 406 expected_type=py_to_go_type(item.expected.type), 407 expected_line=Version( 408 original=item.expected.line, 409 go=go_expected_line, 410 ), 411 testfile=item.testfile, 412 line_num=item.line_num, 413 runopts=converted_runopts, 414 ) 415 416 417 def ast_to_go(sequence, reql_vars, excluded_lines): 418 '''Converts the the parsed test data to go source lines using the 419 visitor classes''' 420 reql_vars = set(reql_vars) 421 for item in sequence: 422 if hasattr(item, 'line_num') and item.line_num in excluded_lines: 423 logger.info("Skipped %s line %d due to exclusion", item.testfile, item.line_num) 424 elif type(item) == process_polyglot.Def: 425 yield def_to_go(item, reql_vars) 426 elif type(item) == process_polyglot.CustomDef: 427 yield GoDef(line=Version(item.line, item.line), 428 testfile=item.testfile, 429 line_num=item.line_num) 430 elif type(item) == process_polyglot.Query: 431 yield query_to_go(item, reql_vars) 432 elif type(item) == SkippedTest: 433 yield item 434 else: 435 assert False, "shouldn't happen, item was {}".format(item) 436 437 438 def is_reql(t): 439 '''Determines if a type is a reql term''' 440 # Other options for module: builtins, ?test?, datetime 441 if not hasattr(t, '__module__'): 442 return True 443 444 return t.__module__ == 'rethinkdb.ast' 445 446 447 def escape_string(s, out): 448 out.write('"') 449 for codepoint in s: 450 rpr = repr(codepoint)[1:-1] 451 if rpr.startswith('\\x'): 452 rpr = '\\u00' + rpr[2:] 453 elif rpr == '"': 454 rpr = r'\"' 455 out.write(rpr) 456 out.write('"') 457 458 459 def attr_matches(path, node): 460 '''Helper function. Several places need to know if they are an 461 attribute of some root object''' 462 root, name = path.split('.') 463 ret = is_name(root, node.value) and node.attr == name 464 return ret 465 466 467 def is_name(name, node): 468 '''Determine if the current attribute node is a Name with the 469 given name''' 470 return type(node) == ast.Name and node.id == name 471 472 473 def convert_runopts(reql_vars, type_, runopts): 474 if runopts is None: 475 return None 476 return { 477 camel(key): GoVisitor( 478 reql_vars, type_=type_).convert(val) 479 for key, val in runopts.items() 480 } 481 482 483 class GoVisitor(ast.NodeVisitor): 484 '''Converts python ast nodes into a Go string''' 485 486 def __init__(self, 487 reql_vars=frozenset("r"), 488 out=None, 489 type_=None, 490 is_def=False, 491 smart_bracket=False, 492 ): 493 self.out = StringIO() if out is None else out 494 self.reql_vars = reql_vars 495 self.type = py_to_go_type(type_) 496 self._type = type_ 497 self.is_def = is_def 498 self.smart_bracket = smart_bracket 499 super(GoVisitor, self).__init__() 500 self.write = self.out.write 501 502 def skip(self, message, *args, **kwargs): 503 cls = Skip 504 is_fatal = kwargs.pop('fatal', False) 505 if self.is_def or is_fatal: 506 cls = FatalSkip 507 raise cls(message, *args, **kwargs) 508 509 def convert(self, node): 510 '''Convert a text line to another text line''' 511 self.visit(node) 512 return self.out.getvalue() 513 514 def join(self, sep, items): 515 first = True 516 for item in items: 517 if first: 518 first = False 519 else: 520 self.write(sep) 521 self.visit(item) 522 523 def to_str(self, s): 524 escape_string(s, self.out) 525 526 def cast_null(self, arg): 527 if (type(arg) == ast.Name and arg.id == 'null') or \ 528 (type(arg) == ast.NameConstant and arg.value == None): 529 self.write("nil") 530 else: 531 self.visit(arg) 532 533 def to_args(self, args, func='', optargs=[]): 534 optargs_first = False 535 536 self.write("(") 537 if args: 538 self.cast_null(args[0]) 539 for arg in args[1:]: 540 self.write(', ') 541 self.cast_null(arg) 542 self.write(")") 543 544 if optargs: 545 self.to_args_optargs(func, optargs) 546 547 def to_args_optargs(self, func='', optargs=[]): 548 optarg_aliases = { 549 'JsOpts': 'JSOpts', 550 'HttpOpts': 'HTTPOpts', 551 'Iso8601Opts': 'ISO8601Opts', 552 'IndexCreateFuncOpts': 'IndexCreateOpts' 553 } 554 optarg_field_aliases = { 555 'nonvoting_replica_tags': 'NonVotingReplicaTags', 556 } 557 558 if not func: 559 raise Unhandled("Missing function name") 560 561 optarg_type = camel(func) + 'Opts' 562 optarg_type = optarg_aliases.get(optarg_type, optarg_type) 563 optarg_type = 'r.' + optarg_type 564 565 self.write('.OptArgs(') 566 self.write(optarg_type) 567 self.write('{') 568 for optarg in optargs: 569 # Hack to skip tests that check for unknown opt args, 570 # this is not possible in Go due to static types 571 if optarg.arg == 'return_vals': 572 self.skip("test not required since optargs are statically typed") 573 return 574 if optarg.arg == 'foo': 575 self.skip("test not required since optargs are statically typed") 576 return 577 if type(optarg.value) == ast.Name and optarg.value.id == 'null': 578 self.skip("test not required since go does not support null optargs") 579 return 580 581 582 field_name = optarg_field_aliases.get(optarg.arg, camel(optarg.arg)) 583 584 self.write(field_name) 585 self.write(": ") 586 self.cast_null(optarg.value) 587 self.write(', ') 588 self.write('})') 589 590 def generic_visit(self, node): 591 logger.error("While translating: %s", ast.dump(node)) 592 logger.error("Got as far as: %s", ''.join(self.out)) 593 raise Unhandled("Don't know what this thing is: " + str(type(node))) 594 595 def visit_Assign(self, node): 596 if len(node.targets) != 1: 597 Unhandled("We only support assigning to one variable") 598 var = node.targets[0].id 599 self.write("var " + var + " ") 600 601 if is_reql(self._type): 602 self.write('r.Term') 603 else: 604 self.write(self.type) 605 self.write(" = ") 606 if is_reql(self._type): 607 ReQLVisitor(self.reql_vars, 608 out=self.out, 609 type_=self.type, 610 is_def=True, 611 ).visit(node.value) 612 elif var == 'upper_limit': # Manually set value since value in test causes an error 613 self.write('2<<52 - 1') 614 elif var == 'lower_limit': # Manually set value since value in test causes an error 615 self.write('1 - 2<<52') 616 else: 617 self.visit(node.value) 618 619 def visit_Str(self, node): 620 if node.s == 'ReqlServerCompileError': 621 node.s = 'ReqlCompileError' 622 # Hack to skip irrelevant tests 623 if match(".*Expected .* argument", node.s): 624 self.skip("argument checks not supported") 625 if match(".*argument .* must", node.s): 626 self.skip("argument checks not supported") 627 return 628 if node.s == 'Object keys must be strings.*': 629 self.skip('the Go driver automatically converts object keys to strings') 630 return 631 if node.s.startswith('\'module\' object has no attribute '): 632 self.skip('test not required since terms are statically typed') 633 return 634 self.to_str(node.s) 635 636 def visit_Bytes(self, node, skip_prefix=False, skip_suffix=False): 637 if not skip_prefix: 638 self.write("[]byte{") 639 for i, byte in enumerate(node.s): 640 self.write(str(byte)) 641 if i < len(node.s)-1: 642 self.write(",") 643 if not skip_suffix: 644 self.write("}") 645 else: 646 self.write(", ") 647 648 def visit_Name(self, node): 649 name = node.id 650 if name == 'frozenset': 651 self.skip("can't convert frozensets to GroupedData yet") 652 if name in GO_KEYWORDS: 653 name += '_' 654 self.write({ 655 'True': 'true', 656 'False': 'false', 657 'None': 'nil', 658 'nil': 'nil', 659 'null': 'nil', 660 'float': 'float64', 661 'datetime': 'Ast', # Use helper method instead 662 'len': 'maybeLen', 663 'AnythingIsFine': 'compare.AnythingIsFine', 664 'bag': 'compare.UnorderedMatch', 665 'partial': 'compare.PartialMatch', 666 'uuid': 'compare.IsUUID', 667 'regex': 'compare.MatchesRegexp', 668 }.get(name, name)) 669 670 def visit_arg(self, node): 671 self.write(node.arg) 672 self.write(" ") 673 if is_reql(self._type): 674 self.write("r.Term") 675 else: 676 self.write("interface{}") 677 678 def visit_NameConstant(self, node): 679 if node.value is None: 680 self.write("nil") 681 elif node.value is True: 682 self.write("true") 683 elif node.value is False: 684 self.write("false") 685 else: 686 raise Unhandled( 687 "Don't know NameConstant with value %s" % node.value) 688 689 def visit_Attribute(self, node, emit_parens=True): 690 skip_parent = False 691 if attr_matches("r.ast", node): 692 # The java driver doesn't have that namespace, so we skip 693 # the `r.` prefix and create an ast class member in the 694 # test file. So stuff like `r.ast.rqlTzinfo(...)` converts 695 # to `ast.rqlTzinfo(...)` 696 skip_parent = True 697 698 699 if not skip_parent: 700 self.visit(node.value) 701 self.write(".") 702 703 attr = ReQLVisitor.convertTermName(node.attr) 704 self.write(attr) 705 706 def visit_Num(self, node): 707 self.write(repr(node.n)) 708 if not isinstance(node.n, float): 709 if node.n > 9223372036854775807 or node.n < -9223372036854775808: 710 self.write(".0") 711 712 def visit_Index(self, node): 713 self.visit(node.value) 714 715 def visit_Call(self, node): 716 func = '' 717 if type(node.func) == ast.Attribute and node.func.attr == 'encode': 718 self.skip("Go tests do not currently support character encoding") 719 return 720 if type(node.func) == ast.Attribute and node.func.attr == 'error': 721 # This weird special case is because sometimes the tests 722 # use r.error and sometimes they use r.error(). The java 723 # driver only supports r.error(). Since we're coming in 724 # from a call here, we have to prevent visit_Attribute 725 # from emitting the parents on an r.error for us. 726 self.visit_Attribute(node.func, emit_parens=False) 727 elif type(node.func) == ast.Name and node.func.id == 'err': 728 # Throw away third argument as it is not used by the Go tests 729 # and Go does not support function overloading 730 node.args = node.args[:2] 731 self.visit(node.func) 732 elif type(node.func) == ast.Name and node.func.id == 'err_regex': 733 # Throw away third argument as it is not used by the Go tests 734 # and Go does not support function overloading 735 node.args = node.args[:2] 736 self.visit(node.func) 737 elif type(node.func) == ast.Name and node.func.id == 'fetch': 738 if len(node.args) == 1: 739 node.args.append(ast.Constant(0)) 740 elif len(node.args) > 2: 741 node.args = node.args[:2] 742 self.visit(node.func) 743 else: 744 self.visit(node.func) 745 746 if type(node.func) == ast.Attribute: 747 func = node.func.attr 748 elif type(node.func) == ast.Name: 749 func = node.func.id 750 751 self.to_args(node.args, func, node.keywords) 752 753 def visit_Dict(self, node): 754 self.write("map[interface{}]interface{}{") 755 for k, v in zip(node.keys, node.values): 756 self.visit(k) 757 self.write(": ") 758 self.visit(v) 759 self.write(", ") 760 self.write("}") 761 762 def visit_List(self, node): 763 self.write("[]interface{}{") 764 self.join(", ", node.elts) 765 self.write("}") 766 767 def visit_Tuple(self, node): 768 self.visit_List(node) 769 770 def visit_Lambda(self, node): 771 self.write("func") 772 self.to_args(node.args.args) 773 self.write(" interface{} { return ") 774 self.visit(node.body) 775 self.write("}") 776 777 def visit_Subscript(self, node): 778 if node.slice is not None and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == int: 779 self.visit(node.value) 780 self.write("[") 781 self.write(str(node.slice.value.n)) 782 self.write("]") 783 else: 784 logger.error("While doing: %s", ast.dump(node)) 785 raise Unhandled("Only integers subscript can be converted." 786 " Got %s" % node.slice.value.s) 787 788 def visit_ListComp(self, node): 789 gen = node.generators[0] 790 791 start = 0 792 end = 0 793 794 if type(gen.iter) == ast.Call and gen.iter.func.id.endswith('range'): 795 # This is really a special-case hacking of [... for i in 796 # range(i)] comprehensions that are used in the polyglot 797 # tests sometimes. It won't handle translating arbitrary 798 # comprehensions to Java streams. 799 800 if len(gen.iter.args) == 1: 801 end = gen.iter.args[0].n 802 elif len(gen.iter.args) == 2: 803 start = gen.iter.args[0].n 804 end = gen.iter.args[1].n 805 else: 806 # Somebody came up with a creative new use for 807 # comprehensions in the test suite... 808 raise Unhandled("ListComp hack couldn't handle: ", ast.dump(node)) 809 810 self.write("(func() []interface{} {\n") 811 self.write(" res := []interface{}{}\n") 812 self.write(" for iterator_ := %s; iterator_ < %s; iterator_++ {\n" % (start, end)) 813 self.write(" "); self.visit(gen.target); self.write(" := iterator_\n") 814 self.write(" res = append(res, ") 815 self.visit(node.elt) 816 self.write(")\n") 817 self.write(" }\n") 818 self.write(" return res\n") 819 self.write("}())") 820 821 def visit_UnaryOp(self, node): 822 opMap = { 823 ast.USub: "-", 824 ast.Not: "!", 825 ast.UAdd: "+", 826 ast.Invert: "~", 827 } 828 self.write(opMap[type(node.op)]) 829 self.visit(node.operand) 830 831 def visit_BinOp(self, node): 832 opMap = { 833 ast.Add: " + ", 834 ast.Sub: " - ", 835 ast.Mult: " * ", 836 ast.Div: " / ", 837 ast.Mod: " % ", 838 } 839 if self.is_string_mul(node): 840 return 841 if self.is_array_concat(node): 842 return 843 if self.is_byte_array_add(node): 844 return 845 846 t = type(node.op) 847 if t in opMap.keys(): 848 self.visit(node.left) 849 self.write(opMap[t]) 850 self.visit(node.right) 851 elif t == ast.Pow: 852 if type(node.left) == ast.Constant and type(node.left.value) == int and node.left.n == 2: 853 self.visit(node.left) 854 self.write(" << ") 855 self.visit(node.right) 856 else: 857 raise Unhandled("Can't do exponent with non 2 base") 858 859 def is_byte_array_add(self, node): 860 '''Some places we do stuff like b'foo' + b'bar' and byte 861 arrays don't like that much''' 862 if (type(node.left) == ast.Bytes and 863 type(node.right) == ast.Bytes and 864 type(node.op) == ast.Add): 865 self.visit_Bytes(node.left, skip_suffix=True) 866 self.visit_Bytes(node.right, skip_prefix=True) 867 return True 868 else: 869 return False 870 871 def is_string_mul(self, node): 872 if ((type(node.left) == ast.Constant and type(node.left.value) == str and type(node.right) == ast.Constant and type(node.right.value) == int) and type(node.op) == ast.Mult): 873 self.write("\"") 874 self.write(node.left.s * node.right.n) 875 self.write("\"") 876 return True 877 elif ((type(node.left) == ast.Constant and type(node.left.value) == int and type(node.right) == ast.Constant and type(node.right.value) == str) and type(node.op) == ast.Mult): 878 self.write("\"") 879 self.write(node.left.n * node.right.s) 880 self.write("\"") 881 return True 882 else: 883 return False 884 885 def is_array_concat(self, node): 886 if ((type(node.left) == ast.List or type(node.right) == ast.List) and type(node.op) == ast.Add): 887 self.skip("Array concatenation using + operator not currently supported") 888 return True 889 else: 890 return False 891 892 893 class ReQLVisitor(GoVisitor): 894 '''Mostly the same as the GoVisitor, but converts some 895 reql-specific stuff. This should only be invoked on an expression 896 if it's already known to return true from is_reql''' 897 898 TOPLEVEL_CONSTANTS = { 899 'error' 900 } 901 902 def visit_BinOp(self, node): 903 if self.is_string_mul(node): 904 return 905 if self.is_array_concat(node): 906 return 907 if self.is_byte_array_add(node): 908 return 909 opMap = { 910 ast.Add: "Add", 911 ast.Sub: "Sub", 912 ast.Mult: "Mul", 913 ast.Div: "Div", 914 ast.Mod: "Mod", 915 ast.BitAnd: "And", 916 ast.BitOr: "Or", 917 } 918 func = opMap[type(node.op)] 919 if self.is_not_reql(node.left): 920 self.prefix(func, node.left, node.right) 921 else: 922 self.infix(func, node.left, node.right) 923 924 def visit_Compare(self, node): 925 opMap = { 926 ast.Lt: "Lt", 927 ast.Gt: "Gt", 928 ast.GtE: "Ge", 929 ast.LtE: "Le", 930 ast.Eq: "Eq", 931 ast.NotEq: "Ne", 932 } 933 if len(node.ops) != 1: 934 # Python syntax allows chained comparisons (a < b < c) but 935 # we don't deal with that here 936 raise Unhandled("Compare hack bailed on: ", ast.dump(node)) 937 left = node.left 938 right = node.comparators[0] 939 func_name = opMap[type(node.ops[0])] 940 if self.is_not_reql(node.left): 941 self.prefix(func_name, left, right) 942 else: 943 self.infix(func_name, left, right) 944 945 def prefix(self, func_name, left, right): 946 self.write("r.") 947 self.write(func_name) 948 self.write("(") 949 self.visit(left) 950 self.write(", ") 951 self.visit(right) 952 self.write(")") 953 954 def infix(self, func_name, left, right): 955 self.visit(left) 956 self.write(".") 957 self.write(func_name) 958 self.write("(") 959 self.visit(right) 960 self.write(")") 961 962 def is_not_reql(self, node): 963 return type(node) in (ast.Constant, ast.Name, ast.NameConstant, ast.Dict, ast.List) 964 965 def visit_Subscript(self, node): 966 self.visit(node.value) 967 if type(node.slice) == ast.Index: 968 # Syntax like a[2] or a["b"] 969 if self.smart_bracket and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == str: 970 self.write(".Field(") 971 elif self.smart_bracket and type(node.slice.value) == ast.Constant and type(node.slice.value.value) == int: 972 self.write(".Nth(") 973 else: 974 self.write(".AtIndex(") 975 self.visit(node.slice.value) 976 self.write(")") 977 elif type(node.slice) == ast.Slice: 978 # Syntax like a[1:2] or a[:2] 979 self.write(".Slice(") 980 lower, upper, rclosed = self.get_slice_bounds(node.slice) 981 self.write(str(lower)) 982 self.write(", ") 983 self.write(str(upper)) 984 if rclosed: 985 self.write(', r.SliceOpts{RightBound: "closed"}') 986 self.write(")") 987 else: 988 raise Unhandled("No translation for ExtSlice") 989 990 def get_slice_bounds(self, slc): 991 '''Used to extract bounds when using bracket slice 992 syntax. This is more complicated since Python3 parses -1 as 993 UnaryOp(op=USub, operand=Num(1)) instead of Num(-1) like 994 Python2 does''' 995 if not slc: 996 return 0, -1, True 997 998 def get_bound(bound, default): 999 if bound is None: 1000 return default 1001 elif type(bound) == ast.UnaryOp and type(bound.op) == ast.USub: 1002 return -bound.operand.n 1003 elif type(bound) == ast.Constant and type(bound.value) == int: 1004 return bound.n 1005 else: 1006 raise Unhandled( 1007 "Not handling bound: %s" % ast.dump(bound)) 1008 1009 right_closed = slc.upper is None 1010 1011 return get_bound(slc.lower, 0), get_bound(slc.upper, -1), right_closed 1012 1013 def convertTermName(term): 1014 python_clashes = { 1015 # These are underscored in the python driver to avoid 1016 # keywords, but they aren't java keywords so we convert 1017 # them back. 1018 'or_': 'Or', 1019 'and_': 'And', 1020 'not_': 'Not', 1021 } 1022 method_aliases = { 1023 'get_field': 'Field', 1024 'db': 'DB', 1025 'db_create': 'DBCreate', 1026 'db_drop': 'DBDrop', 1027 'db_list': 'DBList', 1028 'uuid': 'UUID', 1029 'geojson': 'GeoJSON', 1030 'js': 'JS', 1031 'json': 'JSON', 1032 'to_json': 'ToJSON', 1033 'to_json_string': 'ToJSON', 1034 'minval': 'MinVal', 1035 'maxval': 'MaxVal', 1036 'http': 'HTTP', 1037 'iso8601': 'ISO8601', 1038 'to_iso8601': 'ToISO8601', 1039 } 1040 1041 return python_clashes.get(term, method_aliases.get(term, camel(term))) 1042 1043 def visit_Attribute(self, node, emit_parens=True): 1044 is_toplevel_constant = False 1045 # if attr_matches("r.row", node): 1046 # elif is_name("r", node.value) and node.attr in self.TOPLEVEL_CONSTANTS: 1047 if is_name("r", node.value) and node.attr in self.TOPLEVEL_CONSTANTS: 1048 # Python has r.minval, r.saturday etc. We need to emit 1049 # r.minval() and r.saturday() 1050 is_toplevel_constant = True 1051 1052 initial = ReQLVisitor.convertTermName(node.attr) 1053 1054 self.visit(node.value) 1055 self.write(".") 1056 self.write(initial) 1057 if initial in GO_KEYWORDS: 1058 self.write('_') 1059 if emit_parens and is_toplevel_constant: 1060 self.write('()') 1061 1062 def visit_UnaryOp(self, node): 1063 if type(node.op) == ast.Invert: 1064 self.visit(node.operand) 1065 self.write(".Not()") 1066 else: 1067 super(ReQLVisitor, self).visit_UnaryOp(node) 1068 1069 def visit_Call(self, node): 1070 if (attr_equals(node.func, "attr", "index_create") and len(node.args) == 2): 1071 node.func.attr = 'index_create_func' 1072 1073 # We call the superclass first, so if it's going to fail 1074 # because of r.row or other things it fails first, rather than 1075 # hitting the checks in this method. Since everything is 1076 # written to a stringIO object not directly to a file, if we 1077 # bail out afterwards it's still ok 1078 super_result = super(ReQLVisitor, self).visit_Call(node) 1079 1080 # r.expr(v, 1) should be skipped 1081 if (attr_equals(node.func, "attr", "expr") and len(node.args) > 1): 1082 self.skip("the go driver only accepts one parameter to expr") 1083 # r.table_create("a", "b") should be skipped 1084 if (attr_equals(node.func, "attr", "table_create") and len(node.args) > 1): 1085 self.skip("the go driver only accepts one parameter to table_create") 1086 return super_result 1087 1088 class Renderer(object): 1089 '''Manages rendering templates''' 1090 1091 def __init__(self, invoking_filenames, source_files=None): 1092 self.template_file = './template.go.tpl' 1093 self.invoking_filenames = invoking_filenames 1094 self.source_files = source_files or [] 1095 self.tpl = Template(filename=self.template_file) 1096 self.template_context = { 1097 'EmptyTemplate': process_polyglot.EmptyTemplate, 1098 } 1099 1100 def render(self, 1101 template_name, 1102 output_dir, 1103 output_name=None, 1104 **kwargs): 1105 if output_name is None: 1106 output_name = template_name 1107 1108 output_path = output_dir + '/' + output_name 1109 1110 results = self.template_context.copy() 1111 results.update(kwargs) 1112 try: 1113 rendered = self.tpl.render(**results) 1114 except process_polyglot.EmptyTemplate: 1115 logger.debug(" Empty template: %s", output_path) 1116 return 1117 with codecs.open(output_path, "w", "utf-8") as outfile: 1118 logger.info("Rendering %s", output_path) 1119 outfile.write(self.autogenerated_header( 1120 self.template_file, 1121 output_path, 1122 self.invoking_filenames, 1123 )) 1124 outfile.write(rendered) 1125 1126 def autogenerated_header(self, template_path, output_path, filename): 1127 rel_tpl = os.path.relpath(template_path, start=output_path) 1128 filenames = ' and '.join(os.path.basename(f) 1129 for f in self.invoking_filenames) 1130 return ('// Code generated by {}.\n' 1131 '// Do not edit this file directly.\n' 1132 '// The template for this file is located at:\n' 1133 '// {}\n').format(filenames, rel_tpl) 1134 1135 def camel(varname): 1136 'CamelCase' 1137 if re.match(r'[A-Z][A-Z0-9_]*$|[a-z][a-z0-9_]*$', varname): 1138 # if snake-case (upper or lower) camelize it 1139 suffix = "_" if varname.endswith('_') else "" 1140 return ''.join(x.title() for x in varname.split('_')) + suffix 1141 else: 1142 # if already mixed case, just capitalize the first letter 1143 return varname[0].upper() + varname[1:] 1144 1145 1146 def dromedary(varname): 1147 'dromedaryCase' 1148 if re.match(r'[A-Z][A-Z0-9_]*$|[a-z][a-z0-9_]*$', varname): 1149 chunks = varname.split('_') 1150 suffix = "_" if varname.endswith('_') else "" 1151 return (chunks[0].lower() + 1152 ''.join(x.title() for x in chunks[1:]) + 1153 suffix) 1154 else: 1155 return varname[0].lower() + varname[1:] 1156 1157 def attr_equals(node, attr, value): 1158 '''Helper for digging into ast nodes''' 1159 return hasattr(node, attr) and getattr(node, attr) == value 1160 1161 if __name__ == '__main__': 1162 main()