gitee.com/quant1x/gox@v1.7.6/num/asm/asm2avo.py (about)

     1  """
     2  Hacky, incomplete script for translating clang AT&T assembly into Avo instructions.
     3  python asm2avo.py --help
     4  
     5  - Function names can't be mangled and must end with return type:
     6      _I: int, _B: bool, _F: float32, _D: float64, _V: void
     7  - Data segments not supported, add manually & make loads unaligned
     8  - Last argument is used to pass length of first slice
     9  - Mostly tested on clang 15 output from godbolt
    10  """
    11  import re
    12  import sys
    13  from copy import deepcopy
    14  from dataclasses import dataclass
    15  from typing import List, Union, Optional
    16  
    17  from parsimonious.grammar import Grammar
    18  from parsimonious.nodes import NodeVisitor
    19  
# PEG grammar (parsimonious syntax) for the subset of clang AT&T assembly
# output this script understands.
#
# - `program` is any sequence of function signatures, labels, instructions
#   and skipped whitespace/comments.
# - `signature` is a header of the form `name(type, type, ...):`.
# - `arg` alternatives are tried in the listed order, so `$`-immediates and
#   `%`-registers are recognised before the more permissive `labelref`.
# - `labelref` is `identifier` followed by an empty literal; presumably this
#   keeps it a distinct rule so `visit_labelref` fires (TODO confirm against
#   parsimonious' handling of plain rule references).
# - `nbws` matches whitespace that is not a line break; `comment` runs from
#   `#` to the end of the line; `skip` consumes whitespace and comments.
grammar = Grammar(
    r"""
    program     = (signature / label / instruction / skip)*
    signature   = identifier "(" (type ("," nbws type)*)? "):" skip
    label       = identifier ":" skip
    instruction = identifier nbws (arg ("," nbws arg)*)? skip

    arg         = imm / reg / hex / string / mem / labelref
    imm         = "$" (integer / hex)
    reg         = "%" identifier
    hex         = ~"0x[0-9a-f]+"
    string      = ~"\"[^\"]*\""
    mem         = (integer / labelref)? "(" reg ("," nbws reg ("," nbws integer)?)? ")"
    labelref    = identifier ""

    integer     = ~"-?[0-9]+"
    identifier  = ~"[a-zA-Z0-9_.]+"
    type        = ("double" / "float" / "bool" / "char" / "int" / "unsigned int" / "long" / "unsigned long" / "void") (nbws "*")?
    ws          = ~"\s*" 
    nbws        = ~"[^\S\r\n]*"
    comment     = ~"#[^\r\n]*"
    skip        = (ws comment?)*
    """
)
    44  
    45  # fmt: off
    46  
@dataclass
class CType:
    """A C type appearing in a function signature."""

    # Base type name as matched by the grammar's `type` rule, e.g. "double".
    name: str
    # True when the matched type text ends with "*".
    pointer: bool
    51  
@dataclass
class Imm:
    """An immediate operand (`$...` in AT&T syntax)."""

    # Source text of the immediate with the leading "$" removed.
    value: str
    55  
@dataclass
class Hex:
    """A bare hexadecimal literal operand (not prefixed with "$")."""

    # The literal exactly as written in the source, including "0x".
    value: str
    59  
@dataclass
class String:
    """A double-quoted string operand."""

    # Text between the surrounding quotes.
    value: str
    63  
@dataclass
class Identifier:
    """A raw identifier token (function name, label name or mnemonic)."""

    # Identifier text exactly as matched.
    value: str
    67  
@dataclass
class Reg:
    """A register operand (`%name` in AT&T syntax)."""

    # Upper-cased register name without the leading "%".
    value: str
    71  
@dataclass
class LabelRef:
    """A reference to a label used as an operand or memory offset."""

    # Label text exactly as written in the source (dots included).
    value: str
    75  
@dataclass
class Mem:
    """A memory operand of the form `offset(base, index, scale)`."""

    # Base register; always present in operands the grammar accepts.
    base: Reg
    # Index register, or None when the operand has no index.
    index: Optional[Reg]
    # Scale factor; the AST builder uses 1 when an index is given without a
    # scale, and None when there is no index.
    scale: Optional[int]
    # Displacement: an integer, a label reference, or None when absent.
    offset: Optional[Union[LabelRef, int]]
    82  
@dataclass
class Signature:
    """A C function signature line that starts a new function."""

    # Function name with the two-character return-type suffix removed.
    fn_name: str
    # Return type derived from the name suffix (_I, _B, _F, _D or _V).
    ret_type: CType
    # Parameter types in declaration order.
    arg_types: List[CType]
    88  
@dataclass
class Label:
    """A label definition (`name:`)."""

    # Label name exactly as written in the source.
    name: str
    92  
@dataclass
class Instruction:
    """A single instruction or assembler directive with its operands."""

    # Upper-cased mnemonic; directives keep their leading ".".
    op: str
    # Operands in source (AT&T) order.
    args: List[Union[Imm, Hex, String, Reg, Mem, LabelRef]]
    97  
@dataclass
class Program:
    """Root of the AST: every parsed statement in source order."""

    # Signatures, labels and instructions; skipped text is not included.
    statements: List[Union[Signature, Label, Instruction]]
   101  
   102  # fmt: on
   103  
   104  
   105  class AstBuilder(NodeVisitor):
   106  
   107      # top level statements
   108  
   109      def visit_program(self, node, visited_children):
   110          return Program(statements=[s[0] for s in visited_children if s[0] is not None])
   111  
   112      def visit_signature(self, node, visited_children):
   113          fn_name = visited_children[0].value
   114          if fn_name[-2:] not in ["_I", "_B", "_F", "_D", "_V"]:
   115              sys.exit("Function name must end in _I, _B, _F, _D or _V return type")
   116          ret_type = {
   117              "_I": CType("long", False),
   118              "_B": CType("bool", False),
   119              "_F": CType("float", False),
   120              "_D": CType("double", False),
   121              "_V": CType("void", False),
   122          }[fn_name[-2:]]
   123          arg_types = AstBuilder.collect_types(visited_children, [CType])
   124          return Signature(fn_name=fn_name[:-2], ret_type=ret_type, arg_types=arg_types)
   125  
   126      def visit_label(self, node, visited_children):
   127          return Label(name=visited_children[0].value)
   128  
   129      def visit_instruction(self, node, visited_children):
   130          valid_args = (Imm, Hex, String, Reg, Mem, LabelRef)
   131          op = visited_children[0].value.upper()
   132          args = AstBuilder.collect_types(visited_children, valid_args)
   133          return Instruction(op=op, args=args)
   134  
   135      def visit_skip(self, node, visited_children):
   136          return None
   137  
   138      # instruction args
   139  
   140      def visit_arg(self, node, visited_children):
   141          return visited_children[0]
   142  
   143      def visit_imm(self, node, visited_children):
   144          return Imm(value=node.text[1:])
   145  
   146      def visit_reg(self, node, visited_children):
   147          return Reg(value=node.text[1:].upper())
   148  
   149      def visit_hex(self, node, visited_children):
   150          return Hex(value=node.text)
   151  
   152      def visit_string(self, node, visited_children):
   153          return String(value=node.text[1:-1])
   154  
   155      def visit_mem(self, node, visited_children):
   156          args = AstBuilder.collect_types(visited_children, [int, LabelRef, Reg])
   157          if isinstance(args[0], Reg):
   158              args.insert(0, None)
   159          return Mem(
   160              offset=args[0],
   161              base=args[1],
   162              index=args[2] if len(args) > 2 else None,
   163              scale=args[3] if len(args) > 3 else 1 if len(args) > 2 else None,
   164          )
   165  
   166      def visit_labelref(self, node, visited_children):
   167          return LabelRef(value=node.text)
   168  
   169      # leaves
   170  
   171      def visit_integer(self, node, visited_children):
   172          return int(node.text)
   173  
   174      def visit_identifier(self, node, visited_children):
   175          return Identifier(value=node.text)
   176  
   177      def visit_type(self, node, visited_children):
   178          return CType(
   179              name=visited_children[0][0],
   180              pointer=node.text[-1] == "*",
   181          )
   182  
   183      def generic_visit(self, node, visited_children):
   184          return visited_children or node.text
   185  
   186      @staticmethod
   187      def collect_types(lst, types):
   188          ret = []
   189          if isinstance(lst, tuple(types)):
   190              ret.append(lst)
   191          elif isinstance(lst, list):
   192              for e in lst:
   193                  ret += AstBuilder.collect_types(e, types)
   194          return ret
   195  
   196  
   197  class Generator:
   198      def __init__(self, out=sys.stdout):
   199          self.indent = 0
   200          self.out = out
   201  
   202      def generate(self, ast):
   203          method_name = "generate_" + type(ast).__name__.lower()
   204          getattr(self, method_name, self.default)(ast)
   205  
   206      def default(self, ast):
   207          self.writeln(f"// FIXME (unhandled): {ast}")
   208  
   209      def write(self, s, indent=None):
   210          self.out.write(
   211              "{}{}".format("\t" * (self.indent if indent is None else indent), s)
   212          )
   213  
   214      def writeln(self, s, indent=None):
   215          self.write(s + "\n", indent=indent)
   216  
   217  
   218  class AvoGenerator(Generator):
   219      def __init__(self, *args, suffix=None, **kwargs):
   220          super().__init__(*args, **kwargs)
   221          self.suffix = suffix
   222          self.current_fn = None
   223          self.current_label = None
   224          self.ret_register = None
   225  
   226      def generate_program(self, ast: Program):
   227          for stmt in ast.statements:
   228              self.generate(stmt)
   229          self.close_label()
   230          self.close_fn()
   231  
   232      def generate_signature(self, ast: Signature):
   233          self.close_label()
   234          self.close_fn()
   235          self.open_fn(ast.fn_name)
   236  
   237          def go_type_name(ctype: CType):
   238              slice = "[]" if ctype.pointer else ""
   239              if ctype.name in ["int", "unsigned int", "long", "unsigned long"]:
   240                  return slice + "int"
   241              if ctype.name in ["char", "unsigned char"]:
   242                  return slice + "byte"
   243              if ctype.name in ["bool"]:
   244                  return slice + "bool"
   245              if ctype.name == "double":
   246                  return slice + "float64"
   247              if ctype.name == "float":
   248                  return slice + "float32"
   249              if ctype.name == "void":
   250                  return "uintptr" if ctype.pointer else "void"
   251              sys.exit(f"Unexpected type {ctype}")
   252  
   253          go_ret_type = go_type_name(ast.ret_type)
   254          if ast.ret_type.pointer or go_ret_type in ["bool", "byte", "int", "uintptr"]:
   255              self.ret_register = "RAX"
   256          elif go_ret_type in ["float32", "float64"]:
   257              self.ret_register = "X0"
   258          else:
   259              self.ret_register = None
   260  
   261          go_types = [go_type_name(t) for t in ast.arg_types]
   262          has_len_param = "[]" in "".join(go_types) and go_types[-1:] == ["int"]
   263          if has_len_param:
   264              go_types = go_types[:-1]
   265  
   266          ip_regs = ["R9", "R8", "RCX", "RDX", "RSI", "RDI"]
   267          fd_regs = ["X5", "X4", "X3", "X2", "X1", "X0"]
   268          scalars = ["f", "e", "d", "c", "b", "a"]
   269          vectors = ["w", "v", "u", "z", "y", "x"]
   270          params = []  # [('name', 'type', 'reg')..]
   271          for t in go_types:
   272              name = vectors.pop() if ("[]" in t) else scalars.pop()
   273              reg = fd_regs.pop() if ("float" in t and "[]" not in t) else ip_regs.pop()
   274              params.append((name, t, reg))
   275  
   276          self.write(f'TEXT("{self.add_suffix(ast.fn_name)}", NOSPLIT, "func(')
   277          prev_type = None
   278          for i, (name, t, reg) in enumerate(params):
   279              if i != 0:
   280                  if prev_type != t:
   281                      self.write(" " + prev_type + ", ", indent=0)
   282                  else:
   283                      self.write(", ", indent=0)
   284              self.write(name, indent=0)
   285              prev_type = t
   286          if prev_type is not None:
   287              self.write(" " + prev_type, indent=0)
   288          self.write(
   289              f"){(' '+go_ret_type) if go_ret_type != 'void' else ''}\")\n", indent=0
   290          )
   291          self.writeln('Pragma("noescape")')
   292  
   293          for (name, t, reg) in params:
   294              base = ".Base()" if "[]" in t else ""
   295              self.writeln(f'Load(Param("{name}"){base}, {reg})')
   296          if has_len_param:
   297              name = next(p[0] for p in params if "[]" in p[1])
   298              self.writeln(f'Load(Param("{name}").Len(), {ip_regs.pop()})')
   299          self.writeln("")
   300  
   301      def generate_label(self, ast: Label):
   302          self.close_label()
   303          self.open_label(ast.name.replace(".", ""))
   304  
   305      def generate_instruction(self, ast: Instruction):
   306          ast = self.replace_instruction(ast)
   307          if ast.op.startswith("."):
   308              self.writeln(f"// FIXME (unsupported): {ast}")
   309              return
   310          if ast.op.startswith("RET") and self.ret_register is not None:
   311              self.writeln(f"Store({self.ret_register}, ReturnIndex(0))")
   312          self.write(f"{ast.op}(")
   313          for i, arg in enumerate(ast.args):
   314              if isinstance(arg, (Hex, String)):
   315                  self.write(arg.value, indent=0)
   316              elif isinstance(arg, Imm):
   317                  if arg.value.startswith("0x") or int(arg.value) >= 0:
   318                      self.write(f"Imm({arg.value})", indent=0)
   319                  else:
   320                      self.write(f"I32({arg.value})", indent=0)
   321              else:
   322                  self.generate(arg)
   323              if i != len(ast.args) - 1:
   324                  self.write(", ", indent=0)
   325          self.writeln(")", indent=0)
   326  
   327      def generate_reg(self, ast: Reg):
   328          reg = ast.value
   329          if re.match(r"[XYZ]MM", reg):
   330              reg = reg[0] + reg[3:]
   331          if re.match(r"R\d+D$", reg):
   332              reg = reg[:-1] + "L"
   333          self.write(f"{reg}", indent=0)
   334  
   335      def generate_labelref(self, ast: LabelRef):
   336          self.write(f'LabelRef("{ast.value.replace(".", "")}")', indent=0)
   337  
   338      def generate_mem(self, ast: Mem):
   339          self.write(f"Mem{{Base: {ast.base.value}}}", indent=0)
   340          if ast.index is not None:
   341              self.write(f".Idx({ast.index.value}, {ast.scale})", indent=0)
   342          if ast.offset is not None:
   343              self.write(f".Offset({ast.offset})", indent=0)
   344  
   345      def open_fn(self, name):
   346          self.current_fn = name
   347          self.writeln("")
   348          self.writeln("func gen" + name + "() {\n")
   349          self.indent += 1
   350  
   351      def close_fn(self):
   352          if self.current_fn is not None:
   353              self.indent -= 1
   354              self.writeln("}")
   355              self.current_fn = None
   356  
   357      def open_label(self, name):
   358          self.current_label = name
   359          self.writeln("")
   360          self.writeln(f'Label("{name}")')
   361          self.writeln("{")
   362          self.indent += 1
   363  
   364      def close_label(self):
   365          if self.current_label is not None:
   366              self.indent -= 1
   367              self.writeln("}")
   368              self.current_label = None
   369  
   370      def add_suffix(self, name):
   371          m = re.match(r"(.*)_(F..)$", name)
   372          return (
   373              m.group(1) + "_" + self.suffix + "_" + m.group(2)
   374              if m and self.suffix
   375              else name
   376          )
   377  
   378      def replace_instruction(self, ast: Instruction):
   379          if ast.op.startswith("RET"):
   380              return Instruction(op="RET", args=[])
   381          if re.match(r"CMP[QBL]$", ast.op):
   382              return Instruction(op=ast.op, args=[ast.args[1], ast.args[0]])
   383          if re.match(r"VCVTTSD2SI$", ast.op) and ast.args[-1].value[0] == "R":
   384              return Instruction(op=ast.op + "Q", args=ast.args)
   385          if re.match(r"VCVTPD2PS$", ast.op) and ast.args[0].value[0] == "Y":
   386              return Instruction(op=ast.op + "Y", args=ast.args)
   387          if re.match(r"CLTQ$", ast.op):
   388              return Instruction(op="CDQE", args=ast.args)
   389          if re.match(r"MOVSLQ$", ast.op):
   390              return Instruction(op="MOVLQSX", args=ast.args)
   391          m = re.match(r"(VCVTSI2S[SD])(\w*)$", ast.op)
   392          if m and not m.group(2)[:1] in ["Q", "L"]:
   393              op = m.group(1) + "Q" + m.group(2)
   394              return Instruction(op=op, args=ast.args)
   395          m = re.match(r"CMOV([AB])([QL])$", ast.op)
   396          if m:
   397              cmp = {  # https://docs.oracle.com/cd/E19120-01/open.solaris/817-5477/6mkuavhs7/index.html
   398                  "A": "GT",
   399                  "B": "LT",
   400              }[
   401                  m.group(1)
   402              ]
   403              return Instruction(op="CMOV" + m.group(2) + cmp, args=ast.args)
   404          m = re.match(r"VCMP(\w+?)(PD|PS|SD|SS)$", ast.op)
   405          if m:
   406              cmp = m.group(1)
   407              op = "VCMP" + m.group(2)
   408              imm = {  # https://www.felixcloutier.com/x86/cmppd.html
   409                  "EQ": 0,
   410                  "LT": 1,
   411                  "LE": 2,
   412                  "UNORD": 3,
   413                  "NEQ": 4,
   414                  "NLT": 5,
   415                  "NLE": 6,
   416                  "ORD": 7,
   417                  "EQ_UQ": 8,
   418                  "NGE": 9,
   419                  "NGT": 10,
   420                  "FALSE": 11,
   421                  "NEQ_OQ": 12,
   422                  "GE": 13,
   423                  "GT": 14,
   424                  "TRUE": 15,
   425                  "EQ_OS": 16,
   426                  "LT_OQ": 17,
   427                  "LE_OQ": 18,
   428                  "UNORD_S": 19,
   429                  "NEQ_US": 20,
   430                  "NLT_UQ": 21,
   431                  "NLE_UQ": 22,
   432                  "ORD_S": 23,
   433                  "EQ_US": 24,
   434                  "NGE_UQ": 25,
   435                  "NGT_UQ": 26,
   436                  "FALSE_OS": 27,
   437                  "NEQ_OS": 28,
   438                  "GE_OQ": 29,
   439                  "GT_OQ": 30,
   440                  "TRUE_US": 31,
   441              }[cmp]
   442              return Instruction(op=op, args=[Imm(value=str(imm))] + ast.args)
   443          return ast
   444  
   445  
   446  def main():
   447      import argparse
   448      import pprint
   449  
   450      parser = argparse.ArgumentParser()
   451      parser.add_argument("file", help="input file, omit for stdin", nargs="?")
   452      parser.add_argument("--out", help="output to file", required=False)
   453      parser.add_argument("--suffix", help="add suffix to function names", required=False)
   454      parser.add_argument("--cst", help="print concrete syntax tree", action="store_true")
   455      parser.add_argument("--ast", help="print abstract syntax tree", action="store_true")
   456      parsed = parser.parse_args()
   457  
   458      inp = open(parsed.file, "r").read() if parsed.file else sys.stdin.read()
   459      cst = grammar.parse(inp)
   460      if parsed.cst:
   461          print(cst, "\n")
   462      ast = AstBuilder().visit(cst)
   463      if parsed.ast:
   464          pprint.pprint(ast, width=100)
   465      out = open(parsed.out, "w") if parsed.out else sys.stdout
   466      gen = AvoGenerator(out=out, suffix=parsed.suffix)
   467      gen.generate(ast)
   468  
   469  
# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()