github.com/bytedance/sonic@v1.11.7-0.20240517092252-d2edb31b167b/tools/asm2arm/arm.py (about)

     1  #!/usr/bin/env python3
     2  # -*- coding: utf-8 -*-
     3  
     4  import os
     5  import sys
     6  import string
     7  import argparse
     8  import itertools
     9  import functools
    10  
    11  from typing import Any
    12  from typing import Dict
    13  from typing import List
    14  from typing import Type
    15  from typing import Tuple
    16  from typing import Union
    17  from typing import Callable
    18  from typing import Iterable
    19  from typing import Optional
    20  
    21  import mcasm
    22  
    23  class InstrStreamer(mcasm.Streamer):
    24      data:   bytes
    25      instr:  Optional[mcasm.mc.Instruction]
    26      fixups: List[mcasm.mc.Fixup]
    27  
    28      def __init__(self):
    29          self.data = b''
    30          self.instr = None
    31          self.fixups = []
    32          super().__init__()
    33  
    34      def unhandled_event(self, name: str, base_impl, *args, **kwargs):
    35          if name == 'emit_instruction':
    36              self.instr = args[1]
    37              self.data = args[2]
    38              self.fixups = args[3]
    39          return super().unhandled_event(name, base_impl, *args, **kwargs)
    40  
    41  ### Instruction Parser (GAS Syntax) ###
    42  class Token:
    43      tag: int
    44      val: Union[int, str]
    45  
    46      def __init__(self, tag: int, val: Union[int, str]):
    47          self.val = val
    48          self.tag = tag
    49  
    50      @classmethod
    51      def end(cls):
    52          return cls(TOKEN_END, '')
    53  
    54      @classmethod
    55      def reg(cls, reg: str):
    56          return cls(TOKEN_REG, reg)
    57  
    58      @classmethod
    59      def imm(cls, imm: int):
    60          return cls(TOKEN_IMM, imm)
    61  
    62      @classmethod
    63      def num(cls, num: int):
    64          return cls(TOKEN_NUM, num)
    65  
    66      @classmethod
    67      def name(cls, name: str):
    68          return cls(TOKEN_NAME, name)
    69  
    70      @classmethod
    71      def punc(cls, punc: str):
    72          return cls(TOKEN_PUNC, punc)
    73  
    74      def __repr__(self):
    75          if self.tag == TOKEN_END:
    76              return '<END>'
    77          elif self.tag == TOKEN_REG:
    78              return '<REG %s>' % self.val
    79          elif self.tag == TOKEN_IMM:
    80              return '<IMM %d>' % self.val
    81          elif self.tag == TOKEN_NUM:
    82              return '<NUM %d>' % self.val
    83          elif self.tag == TOKEN_NAME:
    84              return '<NAME %s>' % repr(self.val)
    85          elif self.tag == TOKEN_PUNC:
    86              return '<PUNC %s>' % repr(self.val)
    87          else:
    88              return '<UNK:%d %r>' % (self.tag, self.val)
    89  
    90  class Label:
    91      name: str
    92      offs: Optional[int]
    93  
    94      def __init__(self, name: str):
    95          self.name = name
    96          self.offs = None
    97  
    98      def __str__(self):
    99          return self.name
   100  
   101      def __repr__(self):
   102          if self.offs is None:
   103              return '{LABEL %s (unresolved)}' % self.name
   104          else:
   105              return '{LABEL %s (offset: %d)}' % (self.name, self.offs)
   106  
   107      def resolve(self, offs: int):
   108          self.offs = offs
   109  
   110  class Index:
   111      base  : 'Register'
   112      scale : int
   113  
   114      def __init__(self, base: 'Register', scale: int = 1):
   115          self.base  = base
   116          self.scale = scale
   117  
   118      def __str__(self):
   119          if self.scale == 1:
   120              return ',%s' % self.base
   121          elif self.scale >= 2:
   122              return ',%s,%d' % (self.base, self.scale)
   123          else:
   124              raise RuntimeError('invalid parser state: invalid scale')
   125  
   126      def __repr__(self):
   127          if self.scale == 1:
   128              return repr(self.base)
   129          elif self.scale >= 2:
   130              return '%d * %r' % (self.scale, self.base)
   131          else:
   132              raise RuntimeError('invalid parser state: invalid scale')
   133  
   134  class Memory:
   135      base  : Optional['Register']
   136      disp  : Optional['Displacement']
   137      index : Optional[Index]
   138  
   139      def __init__(self, base: Optional['Register'], disp: Optional['Displacement'], index: Optional[Index]):
   140          self.base  = base
   141          self.disp  = disp
   142          self.index = index
   143          self._validate()
   144  
   145      def __str__(self):
   146          return '%s(%s%s)' % (
   147              '' if self.disp  is None else self.disp,
   148              '' if self.base  is None else self.base,
   149              '' if self.index is None else self.index
   150          )
   151  
   152      def __repr__(self):
   153          return '{MEM %r%s%s}' % (
   154              '' if self.base  is None else self.base,
   155              '' if self.index is None else ' + ' + repr(self.index),
   156              '' if self.disp  is None else ' + ' + repr(self.disp)
   157          )
   158  
   159      def _validate(self):
   160          if self.base is None and self.index is None:
   161              raise SyntaxError('either base or index must be specified')
   162  
   163  class Register:
   164      reg: str
   165  
   166      def __init__(self, reg: str):
   167          self.reg = reg.lower()
   168  
   169      def __str__(self):
   170          return '%' + self.reg
   171  
   172      def __repr__(self):
   173          return '{REG %s}' % self.reg
   174  
   175  class Immediate:
   176      val: int
   177      ref: str
   178  
   179      def __init__(self, val: int):
   180          self.ref = ''
   181          self.val = val
   182  
   183      def __str__(self):
   184          return '$%d' % self.val
   185  
   186      def __repr__(self):
   187          return '{IMM bin:%s, oct:%s, dec:%d, hex:%s}' % (
   188              bin(self.val)[2:],
   189              oct(self.val)[2:],
   190              self.val,
   191              hex(self.val)[2:],
   192          )
   193  
   194  class Reference:
   195      ref: str
   196      disp: int
   197      off: Optional[int]
   198  
   199      def __init__(self, ref: str, disp: int = 0):
   200          self.ref = ref
   201          self.disp = disp
   202          self.off = None
   203  
   204      def __str__(self):
   205          if self.off is None:
   206              return self.ref
   207          else:
   208              return '$' + str(self.off)
   209  
   210      def __repr__(self):
   211          if self.off is None:
   212              return '{REF %s + %d (unresolved)}' % (self.ref, self.disp)
   213          else:
   214              return '{REF %s + %d (offset: %d)}' % (self.ref, self.disp, self.off)
   215  
   216      @property
   217      def offset(self) -> int:
   218          if self.off is None:
   219              raise SyntaxError('unresolved reference to ' + repr(self.ref))
   220          else:
   221              return self.off
   222  
   223      def resolve(self, off: int):
   224          self.off = self.disp + off
   225  
   226  Operand = Union[
   227      Label,
   228      Memory,
   229      Register,
   230      Immediate,
   231      Reference,
   232  ]
   233  
   234  Displacement = Union[
   235      Immediate,
   236      Reference,
   237  ]
   238  
   239  TOKEN_END  = 0
   240  TOKEN_REG  = 1
   241  TOKEN_IMM  = 2
   242  TOKEN_NUM  = 3
   243  TOKEN_NAME = 4
   244  TOKEN_PUNC = 5
   245  
   246  ARM_ADRP_IMM_BIT_SIZE = 21
   247  ARM_ADR_WIDTH = 1024 * 1024
   248  
   249  class Instruction:
   250      comments:   str
   251      mnemonic:   str
   252      asm_code:   str
   253      data:       bytes
   254      instr:      Optional[mcasm.mc.Instruction]
   255      fixups:     List[mcasm.mc.Fixup]
   256      offs_:      Optional[int]
   257      ADRP_label: Optional[str]
   258      text_label: Optional[str]
   259      back_label: Optional[str]
   260      ADR_instr:  Optional[str]
   261      adrp_asm:   Optional[str]
   262      is_adrp:    bool
   263  
   264      def __init__(self, line: str, adrp_count=0):
   265          self.comments = ''
   266          self.offs_ = None
   267          self.is_adrp = False
   268          self.asm = mcasm.Assembler('aarch64-apple-macos11')
   269  
   270          self.parse(line, adrp_count)
   271  
   272      def __str__(self):
   273          return self.asm_code
   274  
   275      def __repr__(self):
   276          return '{INSTR %s}' % ( self.asm_code )
   277      
   278      @property
   279      def jmptab(self) -> Optional[str]:
   280          if self.is_adrp and self.label_name.find(CLANG_JUMPTABLE_LABLE) != -1:
   281              return self.label_name
   282  
   283      @property
   284      def size(self) -> int:
   285          return len(self.data)
   286  
   287      @functools.cached_property
   288      def label_name(self) -> Optional[str]:
   289          if len(self.fixups) > 1:
   290              raise RuntimeError('has more than 1 fixup: ' + self.asm_code)
   291          if self.need_reloc:
   292              if self.mnemonic == 'adr':
   293                  return self.fixups[0].value.sub_expr.symbol.name
   294              else:
   295                  return self.fixups[0].value.symbol.name
   296          else:
   297              return None
   298  
   299      @functools.cached_property
   300      def is_branch(self) -> bool:
   301          return self.instr.desc.is_branch or self.is_invoke
   302  
   303      @functools.cached_property
   304      def is_return(self) -> bool:
   305          return self.instr.desc.is_return
   306  
   307      @functools.cached_property
   308      def is_jmpq(self) -> bool:
   309          # return self.mnemonic == 'br'
   310          return False
   311  
   312      @functools.cached_property
   313      def is_jmp(self) -> bool:
   314          return self.mnemonic == 'b'
   315  
   316      @functools.cached_property
   317      def is_invoke(self) -> bool:
   318          return self.instr.desc.is_call
   319  
   320      @property
   321      def is_branch_label(self) -> bool:
   322          return self.is_branch and (len(self.fixups) != 0)
   323  
   324      @property
   325      def need_reloc(self) -> bool:
   326          return (len(self.fixups) != 0)
   327  
   328      def set_label_offset(self, off):
   329          # arm64
   330          self.offs_ = off + 4
   331  
   332      # def _encode_normal_instr(self) -> str:
   333      #     return self.encode(self.data, self.asm_code)
   334  
   335      @functools.cached_property
   336      def encoded(self) -> str:
   337          if self.need_reloc:
   338              return self._encode_reloc_instr()
   339          else:
   340              return self._encode_normal_instr()
   341  
   342      def _check_offs_is_valid(self, bit_size: int):
   343          if abs(self.offs_) > (1 << bit_size):
   344              raise RuntimeError('offset is too larger, [assembly]: %s, [offset]: %d, [valid off size]: %d'
   345                  % (self.asm_code, self.offs_, self.fixups[0].kind_info.bit_size))
   346  
   347      def _encode_adr(self):
   348          buf = int.from_bytes(self.data, byteorder='little')
   349          bit_size = ARM_ADRP_IMM_BIT_SIZE
   350  
   351          self._check_offs_is_valid(bit_size)
   352  
   353          # adrp op: | op | immlo | 1 0 0 0 0 | immhi | Rd |
   354          #          |31  |30   29|28       24|23    5|4  0|
   355          imm_lo = (self.offs_ << 29) & 0x60000000
   356          imm_hi = (self.offs_ << 3) & 0x00FFFFE0
   357          encode_data = (buf + imm_lo + imm_hi).to_bytes(4, byteorder='little')
   358          self.data = encode_data
   359          # return self.encode(encode_data, '%s $%s(%%rip)' % (str(self), self.offs_))
   360  
   361      def _encode_rel32(self):
   362          if self.mnemonic == 'adrp' or self.mnemonic == 'adr':
   363              return self._encode_adr()
   364          buf = int.from_bytes(self.data, byteorder='little')
   365  
   366          imm = self.offs_
   367          imm_size = self.fixups[0].kind_info.bit_size
   368          imm_offset = self.fixups[0].kind_info.bit_offset
   369          if self.fixups[0].kind_info.is_pc_rel == 1:
   370              # except adr and adrp, other PC-releative instructions need times 4
   371              imm = imm >> 2
   372              # immediate bit size has 1-bit for sign
   373              self._check_offs_is_valid(imm_size - 1 + 2)
   374          else:
   375              self._check_offs_is_valid(imm_size)
   376  
   377          imm = imm << imm_offset
   378          mask = (0x1 << (imm_size + imm_offset)) - 1
   379          buf = buf | (imm & mask)
   380          buf = buf.to_bytes(4, byteorder='little')
   381          self.data = buf
   382          # return self.encode(buf, '%s $%s(%%rip)' % (str(self), self.offs_))
   383  
   384      def _encode_page(self):
   385          if self.mnemonic != 'adrp':
   386              raise RuntimeError("not adrp instruction: %s" % self.asm_code)
   387          self.offs_ = self.offs_ >> 12
   388          return self._encode_rel32()
   389  
   390      def _encode_pageoff(self):
   391          self.offs_ = 0
   392          return self._encode_rel32()
   393  
   394      def _fixup_rel32(self):
   395          if self.offs_ is None:
   396              raise RuntimeError('unresolved label %s' % self.label_name)
   397  
   398          if self.mnemonic == 'adr':
   399              self._encode_adr()
   400          elif self.fixups[0].value.variant_kind == mcasm.mc.SymbolRefExpr.VariantKind.PAGEOFF:
   401              self._encode_pageoff()
   402          elif self.fixups[0].value.variant_kind == mcasm.mc.SymbolRefExpr.VariantKind.PAGE:
   403              self._encode_page()
   404          else:
   405              self._encode_rel32()
   406  
   407      def _encode_reloc_instr(self) -> str:
   408          self._fixup_rel32()
   409          return self.encode(self.data, '%s $%s(%%rip)' % (str(self), self.offs_))
   410  
   411      def _encode_normal_instr(self) -> str:
   412          return self.encode(self.data, str(self))
   413  
   414      def _raw_instr(self) -> bytes:
   415          if self.need_reloc:
   416              self._fixup_rel32()
   417          return self.data
   418  
   419      def _fixup_adrp(self, line: str, adrp_count: int) -> str:
   420          reg = line.split()[1].split(',')[0]
   421          self.text_label = line.split()[2].split('@')[0]
   422          self.ADRP_label = self.text_label + '_' + reg + '_' + str(adrp_count)
   423          self.back_label = '_back_adrp_' + str(adrp_count)
   424          self.ADR_instr = 'adr ' + reg + ', ' + self.text_label
   425          self.adrp_asm = line
   426          self.is_adrp = True
   427          line = 'b ' + self.ADRP_label
   428          self.asm_code = line + ' // ' + self.adrp_asm
   429  
   430          return line
   431  
   432      def _parse_by_mcasm(self, line: str):
   433          streamer = InstrStreamer()
   434          # self.asm.assemble(streamer, line, MCPU="", features_str="")
   435          self.asm.assemble(streamer, line)
   436          if streamer.instr is None:
   437              raise RuntimeError('cannot parse assembly: %s' % line)
   438          self.instr = streamer.instr
   439  
   440          # instead of short jump instruction
   441          self.data = streamer.data
   442  
   443          self.fixups = streamer.fixups
   444          self.mnemonic = line.split()[0]
   445  
   446      def convert_to_adr(self):
   447          self.is_adrp = True
   448          adr_asm = self.adrp_asm.replace('adrp', 'adr')
   449          self.asm_code = adr_asm + ' // ' + self.adrp_asm
   450          # self._parse_by_mcasm(adr_asm)
   451          return self.asm_code
   452  
   453      def parse(self, line: str, adrp_count: int):
   454          # machine code
   455          menmonic = line.split()[0]
   456  
   457          self.ADRP_label = None
   458          self.text_label = None
   459          # turn adrp to jmp
   460          if (menmonic == 'adrp'):
   461              line = self.convert_to_adr()
   462          else:
   463              self.asm_code = line
   464  
   465          self._parse_by_mcasm(line)
   466  
   467      @staticmethod
   468      def encode(buf: bytes, comments: str = '') -> str:
   469          i = 0
   470          r = []
   471          n = len(buf)
   472  
   473          # @debug
   474          # while i < n - 3:
   475          #     r.append('%08x' % int.from_bytes(buf[i:i + 4], 'little'))
   476          #     i += 4
   477          # return '\n\t'.join(r)
   478  
   479          if (n % 4 != 0):
   480              raise RuntimeError("Unkown instruction which not encoding 4 bytes: %s " % comments, buf)
   481  
   482          while i < n - 3:
   483              r.append('WORD $0x%08x' % int.from_bytes(buf[i:i + 4], 'little'))
   484              i += 4
   485  
   486          # join them together, and attach the comment if any
   487          if not comments:
   488              return '; '.join(r)
   489          else:
   490              return '%s  // %s' % ('; '.join(r), comments)
   491  
   492      Reg  = Optional[Register]
   493      Disp = Optional[Displacement]
   494  
   495  ### Prototype Parser ###
   496  
   497  ARGS_ORDER_C = [
   498      Register('x0'),
   499      Register('x1'),
   500      Register('x2'),
   501      Register('x3'),
   502      Register('x4'),
   503      Register('x5'),
   504      Register('x6'),
   505      Register('x7'),
   506  ]
   507  
   508  ARGS_ORDER_GO = [
   509      Register('R0'),
   510      Register('R1'),
   511      Register('R2'),
   512      Register('R3'),
   513      Register('R4'),
   514      Register('R5'),
   515      Register('R6'),
   516      Register('R7'),
   517  ]
   518  
   519  FPARGS_ORDER = [
   520      Register('D0'),
   521      Register('D1'),
   522      Register('D2'),
   523      Register('D3'),
   524      Register('D4'),
   525      Register('D5'),
   526      Register('D6'),
   527      Register('D7'),
   528  ]
   529  
   530  class Parameter:
   531      name : str
   532      size : int
   533      creg : Register
   534      goreg: Register
   535  
   536      def __init__(self, name: str, size: int, reg: Register, goreg: Register):
   537          self.creg  = reg
   538          self.goreg = reg
   539          self.name = name
   540          self.size = size
   541  
   542      def __repr__(self):
   543          return '<ARG %s(%d): %s>' % (self.name, self.size, self.creg)
   544  
   545  class Pcsp:
   546      entry: int
   547      maxpc: int
   548      out  : List[Tuple[int, int]]
   549      pc   : int
   550      sp   : int
   551  
   552      def __init__(self, entry: int):
   553          self.out = []
   554          self.maxpc = entry
   555          self.entry = entry
   556          self.pc = entry
   557          self.sp = 0
   558  
   559      def __str__(self) -> str:
   560          ret = '[][2]uint32{\n'
   561          for pc, sp in self.out:
   562              ret += '        {%d, %d},\n' % (pc, sp)
   563          return ret + '    }'
   564  
   565      def optimize(self):
   566          # push the last record
   567          self.out.append((self.pc - self.entry, self.sp))
   568          # sort by pc
   569          self.out.sort(key=lambda x: x[0])
   570          # NOTICE: first pair {1, 0} to be compitable with golang
   571          tmp = [(1, 0)]
   572          lpc, lsp = 0, -1
   573          for pc, sp in self.out:
   574              # sp changed, push new record
   575              if pc != lpc and sp != lsp:
   576                      tmp.append((pc, sp))
   577              # sp unchanged, replace with the higher pc
   578              if pc != lpc and sp == lsp:
   579                  if len(tmp) > 0:
   580                      tmp.pop(-1)
   581                  tmp.append((pc, sp))
   582  
   583              lpc, lsp = pc, sp
   584          self.out = tmp
   585  
   586      def update(self, dpc: int, dsp: int):
   587          self.out.append((self.pc - self.entry, self.sp))
   588          self.pc += dpc
   589          self.sp += dsp
   590          if self.pc > self.maxpc:
   591              self.maxpc = self.pc
   592  
   593  class Prototype:
   594      args: List[Parameter]
   595      retv: Optional[Parameter]
   596  
   597      def __init__(self, retv: Optional[Parameter], args: List[Parameter]):
   598          self.retv = retv
   599          self.args = args
   600  
   601      def __repr__(self):
   602          if self.retv is None:
   603              return '<PROTO (%s)>' % repr(self.args)
   604          else:
   605              return '<PROTO (%r) -> %r>' % (self.args, self.retv)
   606  
   607      @property
   608      def argspace(self) -> int:
   609          return sum(
   610              [v.size for v in self.args],
   611              (0 if self.retv is None else self.retv.size)
   612          )
   613          
   614      @property
   615      def inputspace(self) -> int:
   616          return sum([v.size for v in self.args])
   617  
   618  class PrototypeMap(Dict[str, Prototype]):
   619      @staticmethod
   620      def _dv(c: str) -> int:
   621          if c == '(':
   622              return 1
   623          elif c == ')':
   624              return -1
   625          else:
   626              return 0
   627  
   628      @staticmethod
   629      def _tk(s: str, p: str) -> bool:
   630          return s.startswith(p) and (s == p or s[len(p)].isspace())
   631  
   632      @classmethod
   633      def _punc(cls, s: str) -> bool:
   634          return s in cls.__puncs_
   635  
   636      @staticmethod
   637      def _err(msg: str) -> SyntaxError:
   638          return SyntaxError(
   639              msg + ', ' +
   640              'the parser integrated in this tool is just a text-based parser, ' +
   641              'so please keep the companion .go file as simple as possible and do not use defined types'
   642          )
   643  
   644      @staticmethod
   645      def _align(nb: int) -> int:
   646          return (((nb - 1) >> 3) + 1) << 3
   647  
   648      @classmethod
   649      def _retv(cls, ret: str) -> Tuple[str, int, Register, Register]:
   650          name, size, xmm = cls._args(ret)
   651          reg = Register('d0') if xmm else Register('x0')
   652          return name, size, reg, reg
   653  
   654      @classmethod
   655      def _args(cls, arg: str, sv: str = '') -> Tuple[str, int, bool]:
   656          while True:
   657              if not arg:
   658                  raise SyntaxError('missing type for parameter: ' + sv)
   659              elif arg[0] != '_' and not arg[0].isalnum():
   660                  return (sv,) + cls._size(arg.strip())
   661              elif not sv and arg[0].isdigit():
   662                  raise SyntaxError('invalid character: ' + repr(arg[0]))
   663              else:
   664                  sv += arg[0]
   665                  arg = arg[1:]
   666  
   667      @classmethod
   668      def _size(cls, name: str) -> Tuple[int, bool]:
   669          if name[0] == '*':
   670              return cls._align(8), False
   671          elif name in ('int8', 'uint8', 'byte', 'bool'):
   672              return cls._align(1), False
   673          elif name in ('int16', 'uint16'):
   674              return cls._align(2), False
   675          elif name == 'float32':
   676              return cls._align(4), True
   677          elif name in ('int32', 'uint32', 'rune'):
   678              return cls._align(4), False
   679          elif name == 'float64':
   680              return cls._align(8), True
   681          elif name in ('int64', 'uint64', 'uintptr', 'int', 'Pointer', 'unsafe.Pointer'):
   682              return cls._align(8), False
   683          else:
   684              raise cls._err('unrecognized type "%s"' % name)
   685  
   686      @classmethod
   687      def _func(cls, src: List[str], idx: int, depth: int = 0) -> Tuple[str, int]:
   688          for i in range(idx, len(src)):
   689              for x in map(cls._dv, src[i]):
   690                  if depth + x >= 0:
   691                      depth += x
   692                  else:
   693                      raise cls._err('encountered ")" more than "(" on line %d' % (i + 1))
   694              else:
   695                  if depth == 0:
   696                      return ' '.join(src[idx:i + 1]), i + 1
   697          else:
   698              raise cls._err('unexpected EOF when parsing function signatures')
   699  
   700      @classmethod
   701      def parse(cls, src: str) -> Tuple[str, 'PrototypeMap']:
   702          idx = 0
   703          pkg = ''
   704          ret = PrototypeMap()
   705          buf = src.splitlines()
   706  
   707          # scan through all the lines
   708          while idx < len(buf):
   709              line = buf[idx]
   710              line = line.strip()
   711  
   712              # skip empty lines
   713              if not line:
   714                  idx += 1
   715                  continue
   716  
   717              # check for package name
   718              if cls._tk(line, 'package'):
   719                  idx, pkg = idx + 1, line[7:].strip().split()[0]
   720                  continue
   721  
   722              if OUTPUT_RAW:
   723  
   724                  # extract funcname like "[var ]{funcname} = func(..."
   725                  end = line.find('func(')
   726                  if end == -1:
   727                      idx += 1
   728                      continue
   729                  name = line[:end].strip()
   730                  if name.startswith('var '):
   731                      name = name[4:].strip()
   732  
   733                  # function names must be identifiers
   734                  if not name.isidentifier():
   735                      raise cls._err('invalid function prototype: ' + name)
   736  
   737                  # register a empty prototype
   738                  ret[name] = Prototype(None, [])
   739                  idx += 1
   740  
   741              else:
   742  
   743                  # only cares about those functions that does not have bodies
   744                  if line[-1] == '{' or not cls._tk(line, 'func'):
   745                      idx += 1
   746                      continue
   747  
   748                  # prevent type-aliasing primitive types into other names
   749                  if cls._tk(line, 'type'):
   750                      raise cls._err('please do not declare any type with in the companion .go file')
   751  
   752                  # find the next function declaration
   753                  decl, pos = cls._func(buf, idx)
   754                  func, idx = decl[4:].strip(), pos
   755  
   756                  # find the beginning '('
   757                  nd = 1
   758                  pos = func.find('(')
   759  
   760                  # must have a '('
   761                  if pos == -1:
   762                      raise cls._err('invalid function prototype: ' + decl)
   763  
   764                  # extract the name and signature
   765                  args = ''
   766                  name = func[:pos].strip()
   767                  func = func[pos + 1:].strip()
   768  
   769                  # skip the method declaration
   770                  if not name:
   771                      continue
   772  
   773                  # function names must be identifiers
   774                  if not name.isidentifier():
   775                      raise cls._err('invalid function prototype: ' + decl)
   776  
   777                  # extract the argument list
   778                  while nd and func:
   779                      nch  = func[0]
   780                      func = func[1:]
   781  
   782                      # adjust the nesting level
   783                      nd   += cls._dv(nch)
   784                      args += nch
   785  
   786                  # check for EOF
   787                  if not nd:
   788                      func = func.strip()
   789                  else:
   790                      raise cls._err('unexpected EOF when parsing function prototype: ' + decl)
   791  
   792                  # check for multiple returns
   793                  if ',' in func:
   794                      raise cls._err('can only return a single value (detected by looking for "," within the return list)')
   795  
   796                  # check for return signature
   797                  if not func:
   798                      retv = None
   799                  elif func[0] == '(' and func[-1] == ')':
   800                      retv = Parameter(*cls._retv(func[1:-1]))
   801                  else:
   802                      raise SyntaxError('badly formatted return argument (please use parenthesis and proper arguments naming): ' + func)
   803  
   804                  # extract the argument list
   805                  if not args[:-1]:
   806                      args, alens, axmm = [], [], []
   807                  else:
   808                      args, alens, axmm = list(zip(*[cls._args(v.strip()) for v in args[:-1].split(',')]))
   809  
   810                  # check for the result
   811                  cregs = []
   812                  goregs = []
   813                  idxs = [0, 0]
   814  
   815                  # split the integer & floating point registers
   816                  for xmm in axmm:
   817                      key = 0 if xmm else 1
   818                      seq = FPARGS_ORDER if xmm else ARGS_ORDER_C
   819                      goseq = FPARGS_ORDER if xmm else ARGS_ORDER_GO
   820  
   821                      # check the argument count
   822                      if idxs[key] >= len(seq):
   823                          raise cls._err("too many arguments, consider pack some into a pointer")
   824  
   825                      # add the register
   826                      cregs.append(seq[idxs[key]])
   827                      goregs.append(goseq[idxs[key]])
   828                      idxs[key] += 1
   829  
   830                  # register the prototype
   831                  ret[name] = Prototype(retv, [
   832                      Parameter(arg, size, creg, goreg)
   833                      for arg, size, creg, goreg in zip(args, alens, cregs, goregs)
   834                  ])
   835  
   836          # all done
   837          return pkg, ret
   838  
   839  ### Assembly Source Parser ###
   840  
   841  ESC_IDLE = 0    # escape parser is idleing
   842  ESC_ISTR = 1    # currently inside a string
   843  ESC_BKSL = 2    # encountered backslash, prepare for escape sequences
   844  ESC_HEX0 = 3    # expect the first hexadecimal character of a "\x" escape
   845  ESC_HEX1 = 4    # expect the second hexadecimal character of a "\x" escape
   846  ESC_OCT1 = 5    # expect the second octal character of a "\000" escape
   847  ESC_OCT2 = 6    # expect the third octal character of a "\000" escape
   848  
   849  class Command:
   850      cmd  : str
   851      args : List[Union[str, bytes]]
   852  
   853      def __init__(self, cmd: str, args: List[Union[str, bytes]]):
   854          self.cmd  = cmd
   855          self.args = args
   856  
   857      def __repr__(self):
   858          return '<CMD %s %s>' % (self.cmd, ', '.join(map(repr, self.args)))
   859  
   860      @classmethod
   861      def parse(cls, src: str) -> 'Command':
   862          val = src.split(None, 1)
   863          cmd = val[0]
   864  
   865          # no parameters
   866          if len(val) == 1:
   867              return cls(cmd, [])
   868  
   869          # extract the argument string
   870          idx = 0
   871          esc = 0
   872          pos = None
   873          args = []
   874          vstr = val[1]
   875  
   876          # scan through the whole string
   877          while idx < len(vstr):
   878              nch = vstr[idx]
   879              idx += 1
   880  
   881              # mark the start of the argument
   882              if pos is None:
   883                  pos = idx - 1
   884  
   885              # encountered the delimiter outside of a string
   886              if nch == ',' and esc == ESC_IDLE:
   887                  pos, p = None, pos
   888                  args.append(vstr[p:idx - 1].strip())
   889  
   890              # start of a string
   891              elif nch == '"' and esc == ESC_IDLE:
   892                  esc = ESC_ISTR
   893  
   894              # end of string
   895              elif nch == '"' and esc == ESC_ISTR:
   896                  esc = ESC_IDLE
   897                  pos, p = None, pos
   898                  args.append(vstr[p:idx].strip()[1:-1].encode('utf-8').decode('unicode_escape'))
   899  
   900              # escape characters
   901              elif nch == '\\' and esc == ESC_ISTR:
   902                  esc = ESC_BKSL
   903  
   904              # hexadecimal escape characters (3 chars)
   905              elif esc == ESC_BKSL and nch == 'x':
   906                  esc = ESC_HEX0
   907  
   908              # octal escape characters (3 chars)
   909              elif esc == ESC_BKSL and nch in string.octdigits:
   910                  esc = ESC_OCT1
   911  
   912              # generic escape characters (single char)
   913              elif esc == ESC_BKSL and nch in ('a', 'b', 'f', 'r', 'n', 't', 'v', '"', '\\'):
   914                  esc = ESC_ISTR
   915  
   916              # invalid escape sequence
   917              elif esc == ESC_BKSL:
   918                  raise SyntaxError('invalid escape character: ' + repr(nch))
   919  
   920              # normal characters, simply advance to the next character
   921              elif esc in (ESC_IDLE, ESC_ISTR):
   922                  pass
   923  
   924              # hexadecimal escape characters
   925              elif esc in (ESC_HEX0, ESC_HEX1) and nch.lower() in string.hexdigits:
   926                  esc = ESC_HEX1 if esc == ESC_HEX0 else ESC_ISTR
   927  
   928              # invalid hexadecimal character
   929              elif esc in (ESC_HEX0, ESC_HEX1):
   930                  raise SyntaxError('invalid hexdecimal character: ' + repr(nch))
   931  
   932              # octal escape characters
   933              elif esc in (ESC_OCT1, ESC_OCT2) and nch.lower() in string.octdigits:
   934                  esc = ESC_OCT2 if esc == ESC_OCT1 else ESC_ISTR
   935  
   936              # at most 3 octal digits
   937              elif esc in (ESC_OCT1, ESC_OCT2):
   938                  esc = ESC_ISTR
   939  
   940              # illegal state, should not happen
   941              else:
   942                  raise RuntimeError('illegal state: %d' % esc)
   943  
   944          # check for the last argument
   945          if pos is None:
   946              return cls(cmd, args)
   947  
   948          # add the last argument and build the command
   949          args.append(vstr[pos:].strip())
   950          return cls(cmd, args)
   951  
   952  class Expression:
   953      pos: int
   954      src: str
   955  
   956      def __init__(self, src: str):
   957          self.pos = 0
   958          self.src = src
   959  
   960      @property
   961      def _ch(self) -> str:
   962          return self.src[self.pos]
   963  
   964      @property
   965      def _eof(self) -> bool:
   966          return self.pos >= len(self.src)
   967  
   968      def _rch(self) -> str:
   969          pos, self.pos = self.pos, self.pos + 1
   970          return self.src[pos]
   971  
   972      def _hex(self, ch: str) -> bool:
   973          if len(ch) == 1 and ch[0] == '0':
   974              return self._ch.lower() == 'x'
   975          elif len(ch) <= 1 or ch[1].lower() != 'x':
   976              return self._ch.isdigit()
   977          else:
   978              return self._ch in string.hexdigits
   979  
   980      def _int(self, ch: str) -> Token:
   981          while not self._eof and self._hex(ch):
   982              ch += self._rch()
   983          else:
   984              if ch.lower().startswith('0x'):
   985                  return Token.num(int(ch, 16))
   986              elif ch[0] == '0':
   987                  return Token.num(int(ch, 8))
   988              else:
   989                  return Token.num(int(ch))
   990  
   991      def _name(self, ch: str) -> Token:
   992          while not self._eof and (self._ch == '_' or self._ch.isalnum()):
   993              ch += self._rch()
   994          else:
   995              return Token.name(ch)
   996  
   997      def _read(self, ch: str) -> Token:
   998          if ch.isdigit():
   999              return self._int(ch)
  1000          elif ch.isidentifier():
  1001              return self._name(ch)
  1002          elif ch in ('*', '<', '>') and not self._eof and self._ch == ch:
  1003              return Token.punc(self._rch() * 2)
  1004          elif ch in ('+', '-', '*', '/', '%', '&', '|', '^', '~', '(', ')'):
  1005              return Token.punc(ch)
  1006          else:
  1007              raise SyntaxError('invalid character: ' + repr(ch))
  1008  
  1009      def _peek(self) -> Optional[Token]:
  1010          pos = self.pos
  1011          ret = self._next()
  1012          self.pos = pos
  1013          return ret
  1014  
  1015      def _next(self) -> Optional[Token]:
  1016          while not self._eof and self._ch.isspace():
  1017              self.pos += 1
  1018          else:
  1019              return Token.end() if self._eof else self._read(self._rch())
  1020  
  1021      def _grab(self, tk: Token, getvalue: Callable[[str], int]) -> int:
  1022          if tk.tag == TOKEN_NUM:
  1023              return tk.val
  1024          elif tk.tag == TOKEN_NAME:
  1025              return getvalue(tk.val)
  1026          else:
  1027              raise SyntaxError('integer or identifier expected, got ' + repr(tk))
  1028  
  1029      __pred__ = [
  1030          {'<<', '>>'},
  1031          {'|'},
  1032          {'^'},
  1033          {'&'},
  1034          {'+', '-'},
  1035          {'*', '/', '%'},
  1036          {'**'},
  1037      ]
  1038  
  1039      __binary__ = {
  1040          '+'  : lambda a, b: a + b,
  1041          '-'  : lambda a, b: a - b,
  1042          '*'  : lambda a, b: a * b,
  1043          '/'  : lambda a, b: a / b,
  1044          '%'  : lambda a, b: a % b,
  1045          '&'  : lambda a, b: a & b,
  1046          '^'  : lambda a, b: a ^ b,
  1047          '|'  : lambda a, b: a | b,
  1048          '<<' : lambda a, b: a << b,
  1049          '>>' : lambda a, b: a >> b,
  1050          '**' : lambda a, b: a ** b,
  1051      }
  1052  
  1053      def _eval(self, op: str, v1: int, v2: int) -> int:
  1054          return self.__binary__[op](v1, v2)
  1055  
  1056      def _nest(self, nest: int, getvalue: Callable[[str], int]) -> int:
  1057          ret = self._expr(0, nest + 1, getvalue)
  1058          ntk = self._next()
  1059  
  1060          # it must follows with a ')' operator
  1061          if ntk.tag != TOKEN_PUNC or ntk.val != ')':
  1062              raise SyntaxError('")" expected, got ' + repr(ntk))
  1063          else:
  1064              return ret
  1065  
  1066      def _unit(self, nest: int, getvalue: Callable[[str], int]) -> int:
  1067          tk = self._next()
  1068          tt, tv = tk.tag, tk.val
  1069  
  1070          # check for unary operators
  1071          if tt == TOKEN_NUM:
  1072              return tv
  1073          elif tt == TOKEN_NAME:
  1074              return getvalue(tv)
  1075          elif tt == TOKEN_PUNC and tv == '(':
  1076              return self._nest(nest, getvalue)
  1077          elif tt == TOKEN_PUNC and tv == '+':
  1078              return self._unit(nest, getvalue)
  1079          elif tt == TOKEN_PUNC and tv == '-':
  1080              return -self._unit(nest, getvalue)
  1081          elif tt == TOKEN_PUNC and tv == '~':
  1082              return ~self._unit(nest, getvalue)
  1083          else:
  1084              raise SyntaxError('integer, unary operator or nested expression expected, got ' + repr(tk))
  1085  
  1086      def _term(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
  1087          lv = self._expr(pred + 1, nest, getvalue)
  1088          tk = self._peek()
  1089  
  1090          # scan to the end
  1091          while True:
  1092              tt = tk.tag
  1093              tv = tk.val
  1094  
  1095              # encountered EOF
  1096              if tt == TOKEN_END:
  1097                  return lv
  1098  
  1099              # must be an operator here
  1100              if tt != TOKEN_PUNC:
  1101                  raise SyntaxError('operator expected, got ' + repr(tk))
  1102  
  1103              # check for the operator precedence
  1104              if tv not in self.__pred__[pred]:
  1105                  return lv
  1106  
  1107              # apply the operator
  1108              op = self._next().val
  1109              rv = self._expr(pred + 1, nest, getvalue)
  1110              lv = self._eval(op, lv, rv)
  1111              tk = self._peek()
  1112  
  1113      def _expr(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
  1114          if pred >= len(self.__pred__):
  1115              return self._unit(nest, getvalue)
  1116          else:
  1117              return self._term(pred, nest, getvalue)
  1118  
  1119      def eval(self, getvalue: Callable[[str], int]) -> int:
  1120          return self._expr(0, 0, getvalue)
  1121  
  1122  
  1123  class Instr:
  1124      ALIGN_WIDTH = 48
  1125      len   : int                     = NotImplemented
  1126      instr : Union[str, Instruction] = NotImplemented
  1127  
  1128      def size(self, _: int) -> int:
  1129          return self.len
  1130  
  1131      def formatted(self, pc: int) -> str:
  1132          raise NotImplementedError
  1133  
  1134      @staticmethod
  1135      def raw_formatted(bs: bytes, comm: str, pc: int) -> str:
  1136          t = '\t'
  1137          if bs:
  1138              for b in bs:
  1139                  t +='0x%02x, ' % b
  1140              # if len(bs)<Instr.ALIGN_WIDTH:
  1141              #     t += '\b' * (Instr.ALIGN_WIDTH - len(bs))
  1142          return '%s//%s%s' % (t, ('0x%08x ' % pc) if pc else ' ', comm)
  1143  
  1144  class RawInstr(Instr):
  1145      bs: bytes
  1146      def __init__(self, size: int, instr: str, bs: bytes):
  1147          self.len = size
  1148          self.instr = instr
  1149          self.bs = bs
  1150  
  1151      def formatted(self, _: int) -> str:
  1152          return '\t' + self.instr
  1153  
  1154      def raw_formatted(self, pc: int) -> str:
  1155          return Instr.raw_formatted(self.bs, self.instr, pc)
  1156  
  1157  class IntInstr(Instr):
  1158      comm: str
  1159      func: Callable[[], int]
  1160  
  1161      def __init__(self, size: int, func: Callable[[], int], comments: str = ''):
  1162          self.len = size
  1163          self.func = func
  1164          self.comm = comments
  1165  
  1166      @property
  1167      def raw_bytes(self):
  1168          return self.func().to_bytes(self.len, 'little')
  1169  
  1170      @property
  1171      def instr(self) -> str:
  1172          return Instruction.encode(self.func().to_bytes(self.len, 'little'), self.comm)
  1173  
  1174      def formatted(self, _: int) -> str:
  1175          return '\t' + self.instr
  1176  
  1177      def raw_formatted(self, pc: int) -> str:
  1178          return Instr.raw_formatted(self.func().to_bytes(self.len, 'little'), self.comm, pc)
  1179  
  1180  class X86Instr(Instr):
  1181      def __init__(self, instr: Instruction):
  1182          self.len = instr.size
  1183          self.instr = instr
  1184  
  1185      def resize(self, size: int) -> int:
  1186          self.len = size
  1187          return size
  1188  
  1189      def formatted(self, _: int) -> str:
  1190          return '\t' + self.instr.encoded
  1191  
  1192      def raw_formatted(self, pc: int) -> str:
  1193          return Instr.raw_formatted(self.instr._raw_instr(), str(self.instr), pc)
  1194  
  1195  class LabelInstr(Instr):
  1196      def __init__(self, name: str):
  1197          self.len = 0
  1198          self.instr = name
  1199  
  1200      def formatted(self, _: int) -> str:
  1201          if self.instr.isidentifier():
  1202              return self.instr + ':'
  1203          else:
  1204              return '_LB_%08x: // %s' % (hash(self.instr) & 0xffffffff, self.instr)
  1205  
  1206      def raw_formatted(self, pc: int) -> str:
  1207          return Instr.raw_formatted(None, str(self.instr), pc)
  1208  
  1209  class BranchInstr(Instr):
  1210      def __init__(self, instr: Instruction):
  1211          self.len = instr.size
  1212          self.instr = instr
  1213  
  1214      def formatted(self, _: int) -> str:
  1215          return '\t' + self.instr.encoded
  1216  
  1217      def raw_formatted(self, pc: int) -> str:
  1218          return Instr.raw_formatted(self.instr._raw_instr(), str(self.instr), pc)
  1219  
  1220  class CommentInstr(Instr):
  1221      def __init__(self, text: str):
  1222          self.len = 0
  1223          self.instr = '// ' + text
  1224  
  1225      def formatted(self, _: int) -> str:
  1226          return '\t' + self.instr
  1227  
  1228      def raw_formatted(self, pc: int) -> str:
  1229          return  Instr.raw_formatted(None, str(self.instr), None)
  1230  
  1231  class AlignmentInstr(Instr):
  1232      bits: int
  1233      fill: int
  1234  
  1235      def __init__(self, bits: int, fill: int = 0):
  1236          self.bits = bits
  1237          self.fill = fill
  1238  
  1239      def size(self, pc: int) -> int:
  1240          mask = (1 << self.bits) - 1
  1241          return (mask - (pc & mask) + 1) & mask
  1242  
  1243      def formatted(self, pc: int) -> str:
  1244          buf = bytes([self.fill]) * self.size(pc)
  1245          return '\t' + Instruction.encode(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill))
  1246  
  1247      def raw_formatted(self, pc: int) -> str:
  1248          buf = bytes([self.fill]) * self.size(pc)
  1249          return Instr.raw_formatted(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill), pc)
  1250  
  1251  REG_MAP = {
  1252      'x0'  : ('MOVD'  , 'R0'),
  1253      'x1'  : ('MOVD'  , 'R1'),
  1254      'x2'  : ('MOVD'  , 'R2'),
  1255      'x3'  : ('MOVD'  , 'R3'),
  1256      'x4'  : ('MOVD'  , 'R4'),
  1257      'x5'  : ('MOVD'  , 'R5'),
  1258      'x6'  : ('MOVD'  , 'R6'),
  1259      'x7'  : ('MOVD'  , 'R7'),
  1260      'd0'  : ('FMOVD' , 'F0'),
  1261      'd1'  : ('FMOVD' , 'F1'),
  1262      'd2'  : ('FMOVD' , 'F2'),
  1263      'd3'  : ('FMOVD' , 'F3'),
  1264      'd4'  : ('FMOVD' , 'F4'),
  1265      'd5'  : ('FMOVD' , 'F5'),
  1266      'd6'  : ('FMOVD' , 'F6'),
  1267      'd7'  : ('FMOVD' , 'F7'),
  1268  }
  1269  
  1270  class Counter:
  1271      value: int = 0
  1272  
  1273      @classmethod
  1274      def next(cls) -> int:
  1275          val, cls.value = cls.value, cls.value + 1
  1276          return val
  1277  
  1278  class BasicBlock:
  1279      maxsp: int
  1280      name: str
  1281      weak: bool
  1282      jmptab: bool
  1283      func: bool
  1284      body: List[Instr]
  1285      prevs: List['BasicBlock']
  1286      next: Optional['BasicBlock']
  1287      jump: Optional['BasicBlock']
  1288  
  1289      def __init__(self, name: str, weak: bool = True, jmptab: bool = False, func: bool = False):
  1290          self.maxsp = -1
  1291          self.body = []
  1292          self.prevs = []
  1293          self.name = name
  1294          self.weak = weak
  1295          self.next = None
  1296          self.jump = None
  1297          self.jmptab = jmptab
  1298          self.func = func
  1299  
  1300      def __repr__(self):
  1301          return '{BasicBlock %s}' % repr(self.name)
  1302  
  1303      @property
  1304      def last(self) -> Optional[Instr]:
  1305          return next((v for v in reversed(self.body) if not isinstance(v, CommentInstr)), None)
  1306  
  1307      def if_all_IntInstr_then_2_RawInstr(self):
  1308          is_table = False
  1309          instr_size = 0
  1310          for instr in self.body:
  1311              if isinstance(instr, IntInstr):
  1312                 if not is_table:
  1313                     instr_size = instr.len
  1314                 is_table = True
  1315                 if instr_size != instr.len:
  1316                     instr_size = 0
  1317                 continue
  1318              if isinstance(instr, AlignmentInstr):
  1319                 continue
  1320              if isinstance(instr, LabelInstr):
  1321                 continue
  1322              # others
  1323              return
  1324  
  1325          if not is_table:
  1326              return
  1327  
  1328          # .long or .quad
  1329          if instr_size == 8 or instr_size == 4:
  1330              return
  1331  
  1332          # All instrs are IntInstr, golang asm only suuport WORD and DWORD for arm. We need
  1333          # combine them as 4-bytes RawInstr and align block
  1334          nb = [] # new body
  1335          raw_buf = [];
  1336          comment = ''
  1337  
  1338          # first element is LabelInstr
  1339          for i in range(1, len(self.body)):
  1340              if isinstance(self.body[i], AlignmentInstr):
  1341                  if i != len(self.body) -1:
  1342                      raise RuntimeError("Not support p2algin in : %s" % self.name)
  1343                  continue
  1344  
  1345              raw_buf += self.body[i].raw_bytes
  1346              comment += '// ' + self.body[i].comm + '\n'
  1347  
  1348          align_size = len(raw_buf) % 4
  1349          if align_size != 0:
  1350              raw_buf += int(0).to_bytes(4 - align_size, 'little')
  1351  
  1352          if isinstance(self.body[0], LabelInstr):
  1353              nb.append(self.body[0])
  1354  
  1355          for i in range(0, len(raw_buf), 4):
  1356              buf = raw_buf[i: i + 4]
  1357              nb.append(RawInstr(len(buf), Instruction.encode(buf), buf))
  1358  
  1359          nb.append(CommentInstr(comment))
  1360  
  1361          if isinstance(self.body[-1:-1], AlignmentInstr):
  1362              nb.append(self.body[-1:-1])
  1363          self.body = nb
  1364  
  1365      def size_of(self, pc: int) -> int:
  1366          return functools.reduce(lambda p, v: p + v.size(pc + p), self.body, 0)
  1367  
  1368      def link_to(self, block: 'BasicBlock'):
  1369          self.next = block
  1370          block.prevs.append(self)
  1371  
  1372      def jump_to(self, block: 'BasicBlock'):
  1373          self.jump = block
  1374          block.prevs.append(self)
  1375  
  1376      @classmethod
  1377      def annonymous(cls) -> 'BasicBlock':
  1378          return cls('// bb.%d' % Counter.next(), weak = False)
  1379  
# Prefix of clang-generated jump-table labels (presumably e.g. `LJTI0_0`).
# NOTE(review): not referenced in the visible code; the identifier's
# misspelling ("LABLE") is kept for compatibility with any users elsewhere.
CLANG_JUMPTABLE_LABLE = 'LJTI'
  1381  
  1382  class CodeSection:
  1383      dead   : bool
  1384      export : bool
  1385      blocks : List[BasicBlock]
  1386      labels : Dict[str, BasicBlock]
  1387      jmptabs: Dict[str, List[BasicBlock]]
  1388      funcs  : Dict[str, Pcsp]
  1389      bsmap_ : Dict[str, int]
  1390  
  1391      def __init__(self):
  1392          self.dead   = False
  1393          self.labels = {}
  1394          self.export = False
  1395          self.blocks = [BasicBlock.annonymous()]
  1396          self.jmptabs = {}
  1397          self.funcs = {}
  1398          self.bsmap_ = {}
  1399  
  1400      @classmethod
  1401      def _dfs_jump_first(cls, bb: BasicBlock, visited: Dict[BasicBlock, bool], hook: Callable[[BasicBlock], bool]) -> bool:
  1402          if bb not in visited or not visited[bb]:
  1403              visited[bb] = True
  1404              if bb.jump and not cls._dfs_jump_first(bb.jump, visited, hook):
  1405                  return False
  1406              if bb.next and not cls._dfs_jump_first(bb.next, visited, hook):
  1407                  return False
  1408              return hook(bb)
  1409          else:
  1410              return True
  1411  
  1412      def get_jmptab(self, name: str) -> List[BasicBlock]:
  1413          return self.jmptabs.setdefault(name, [])
  1414  
  1415      def get_block(self, name: str) -> BasicBlock:
  1416          for block in self.blocks:
  1417              if block.name == name:
  1418                  return block
  1419  
  1420      @property
  1421      def block(self) -> BasicBlock:
  1422          return self.blocks[-1]
  1423  
  1424      @property
  1425      def instrs(self) -> Iterable[Instr]:
  1426          for block in self.blocks:
  1427              yield from block.body
  1428  
  1429      def _make(self, name: str, jmptab: bool = False, func: bool = False):
  1430          if func:
  1431          #NOTICE: if it is a function, always set func to be True
  1432              if (old := self.labels.get(name)) and (old.func != func):
  1433                  old.func = True
  1434          return self.labels.setdefault(name, BasicBlock(name, jmptab = jmptab, func = func))
  1435  
  1436      def _next(self, link: BasicBlock):
  1437          if self.dead:
  1438              self.dead = False
  1439          else:
  1440              self.block.link_to(link)
  1441  
  1442      def _decl(self, name: str, block: BasicBlock):
  1443          block.weak = False
  1444          block.body.append(LabelInstr(name))
  1445          self._next(block)
  1446          self.blocks.append(block)
  1447  
  1448      def _kill(self, name: str):
  1449          self.dead = True
  1450          self.block.link_to(self._make(name))
  1451  
  1452      def _split(self, jmp: BasicBlock):
  1453          self.jump = True
  1454          link = BasicBlock.annonymous()
  1455          self.labels[link.name] = link
  1456          self.block.link_to(link)
  1457          self.block.jump_to(jmp)
  1458          self.blocks.append(link)
  1459  
  1460      @staticmethod
  1461      def _mk_align(v: int) -> int:
  1462          if v & 15 == 0:
  1463              return v
  1464          else:
  1465              print('* warning: SP is not aligned with 16 bytes.', file = sys.stderr)
  1466              return (v + 15) & -16
  1467  
  1468      @staticmethod
  1469      def _is_spadj(ins: Instruction) -> bool:
  1470          return len(ins.instr.operands) == 3                         and \
  1471                 isinstance(ins.instr.operands[1], mcasm.mc.Register) and \
  1472                 isinstance(ins.instr.operands[2], int)               and \
  1473                 ins.instr.operands[1].name == 'RSP'
  1474  
  1475      @staticmethod
  1476      def _is_spmove(ins: Instruction, i: int) -> bool:
  1477          return len(ins.operands) == 2                and \
  1478                 isinstance(ins.operands[0], Register) and \
  1479                 isinstance(ins.operands[1], Register) and \
  1480                 ins.operands[i].reg == 'rsp'
  1481  
  1482      @staticmethod
  1483      def _is_rjump(ins: Optional[Instr]) -> bool:
  1484          return isinstance(ins, X86Instr) and ins.instr.is_branch_label
  1485  
  1486      def _find_label(self, name: str, adjs: Iterable[int], size: int = 0) -> int:
  1487          for adj, block in zip(adjs, self.blocks):
  1488              if block.name == name:
  1489                  return size
  1490              else:
  1491                  # find block size from cache
  1492                  v = self.bsmap_.get(block.name)
  1493                  if v is not None:
  1494                      size += v + adj
  1495                  else:
  1496                      block_size = block.size_of(size)
  1497                      size += block_size + adj
  1498                      self.bsmap_[block.name] = block_size
  1499          else:
  1500              raise SyntaxError('unresolved reference to name: ' + name)
  1501  
  1502      def _alloc_instr(self, instr: Instruction):
  1503          if not instr.is_branch_label:
  1504              self.block.body.append(X86Instr(instr))
  1505          else:
  1506              self.block.body.append(BranchInstr(instr))
  1507  
    # It seems that stack alignment cannot be specified inside Go ASM, so we
    # need to replace aligned instructions with unaligned ones if any of their
    # operands is an RBP-relative memory operand.
  1511  
    def _check_align(self, instr: Instruction) -> bool:
        """Hook for rewriting aligned instructions; True means ``instr`` was fully handled here.

        Currently a stub that always returns False, so ``instr()`` always
        allocates the instruction normally.
        """
        # TODO: check
        return False
  1515  
    def _check_split(self, instr: Instruction):
        """Update the control-flow graph after ``instr`` was appended to the current block.

        - return: the section becomes dead (no fall-through).
        - indirect jump (``is_jmpq``): search backwards for an instruction that
          names a jump table and branch the current block to that table.
        - unconditional jump to a label: kill the block, linking it to the target.
        - call: split, with the callee's block flagged as a function entry.
        - conditional branch: split, with the target as the jump edge.
        """
        if instr.is_return:
            self.dead = True

        elif instr.is_jmpq: # jmpq
            # backtrace the jump table from the current block through its
            # predecessors; prevs is used as a LIFO stack (pop() takes the
            # last entry), so this walk is depth-first, not BFS
            prevs = [self.block]
            visited = set()
            while len(prevs) > 0:
                curb = prevs.pop()
                if curb in visited:
                    continue
                else:
                    visited.add(curb)

                # backtrace instructions, newest first; the first one that
                # references a jump table decides the branch target
                for ins in reversed(curb.body):
                    if isinstance(ins, X86Instr) and ins.instr.jmptab:
                        self._split(self._make(ins.instr.jmptab, jmptab = True))
                        return

                if curb.prevs:
                    prevs.extend(curb.prevs)
            # NOTE(review): if no jump table is found, the block is neither
            # split nor killed

        elif instr.is_branch_label:
            if instr.is_jmp:
                self._kill(instr.label_name)
            
            elif instr.is_invoke: # call
                fname = instr.label_name
                self._split(self._make(fname, func = True))

            else: # jeq, ja, jae ...
                self._split(self._make(instr.label_name))
  1550  
    def _trace_block(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int:
        """Return the max stack depth reachable from ``bb``, tracing it if not cached.

        When ``pcsp`` is given, its ``pc`` is moved to ``bb``'s address first.
        A block flagged ``func``, or one located before the tracker's ``entry``,
        starts a new Pcsp that is registered in ``self.funcs``. A block already
        in ``self.funcs`` is traced without pc/sp recording.
        Cache states of ``bb.maxsp``: -1 = untraced (trace now), >= 0 = cached
        depth, anything else (the -2 in-progress marker) counts as 0.
        """
        if (pcsp is not None):
            if bb.name in self.funcs:
                # already traced
                pcsp = None
            else:
                # continue tracing, update the pcsp
                # NOTICE: must mark pcsp at block entry because go only calculate delta value
                pcsp.pc = self.get(bb.name)
                if bb.func or pcsp.pc < pcsp.entry:
                    # new func
                    pcsp = Pcsp(pcsp.pc)
                    self.funcs[bb.name] = pcsp

        if bb.maxsp == -1:
            ret = self._trace_nocache(bb, pcsp)
            return ret
        elif bb.maxsp >= 0:
            return bb.maxsp
        else:
            # tracing of bb is in progress (a cycle back into it): contributes 0
            return 0
  1573      def _trace_nocache(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int:
  1574          bb.maxsp = -2
  1575  
  1576          # ## FIXME:
  1577          # if pcsp is None:
  1578          #     pcsp = Pcsp(0)
  1579  
  1580          # make a fake object just for reducing redundant checking
  1581          if pcsp:
  1582              pc0, sp0 = pcsp.pc, pcsp.sp
  1583  
  1584          maxsp, term = self._trace_instructions(bb, pcsp)
  1585  
  1586          # this is a terminating block
  1587          if term:
  1588              return maxsp
  1589  
  1590          # don't trace it's next block if it's an unconditional jump
  1591          a, b = 0, 0
  1592          if pcsp:
  1593              pc, sp = pcsp.pc, pcsp.sp
  1594  
  1595          if bb.jump:
  1596              if bb.jump.jmptab:
  1597                  cases = self.get_jmptab(bb.jump.name)
  1598                  for case in cases:
  1599                      nsp = self._trace_block(case, pcsp)
  1600                      if pcsp:
  1601                          pcsp.pc, pcsp.sp = pc, sp
  1602                      if nsp > a:
  1603                          a = nsp
  1604              else:
  1605                  a = self._trace_block(bb.jump, pcsp)
  1606                  if pcsp:
  1607                      pcsp.pc, pcsp.sp = pc, sp
  1608  
  1609          if bb.next:
  1610              b = self._trace_block(bb.next, pcsp)
  1611  
  1612          if pcsp:
  1613              pcsp.pc, pcsp.sp = pc0, sp0
  1614  
  1615          # select the maximum stack depth
  1616          bb.maxsp = maxsp + max(a, b)
  1617          return bb.maxsp
  1618  
  1619      def _trace_instructions(self, bb: BasicBlock, pcsp: Pcsp) -> Tuple[int, bool]:
  1620          cursp = 0
  1621          maxsp = 0
  1622          close = False
  1623  
  1624          # scan every instruction
  1625          for ins in bb.body:
  1626              diff = 0
  1627  
  1628              if isinstance(ins, X86Instr):
  1629                  name = ins.instr.mnemonic
  1630                  operands = ins.instr.instr.operands
  1631  
  1632                  # check for instructions
  1633                  if name == 'ret':
  1634                      close = True
  1635                  elif isinstance(operands[0], mcasm.mc.Register) and operands[0].name == 'SP':
  1636                      # print(ins.instr.asm_code)
  1637                      if name == 'add':
  1638                          diff = -self._mk_align(operands[2])
  1639                      elif name == 'sub':
  1640                          diff = self._mk_align(operands[2])
  1641                      elif name == 'stp':
  1642                          diff = -self._mk_align(operands[4] * 8)
  1643                      elif name == 'ldp':
  1644                          diff = -self._mk_align(operands[4] * 8)
  1645                      elif name == 'str':
  1646                          diff = -self._mk_align(operands[3])
  1647                      else:
  1648                          raise RuntimeError("An instruction adjsut sp but bot processed: %s" % ins.instr.asm_code)
  1649  
  1650                  cursp += diff
  1651  
  1652                  # update the max stack depth
  1653                  if cursp > maxsp:
  1654                      maxsp = cursp
  1655  
  1656              # update pcsp
  1657              if pcsp:
  1658                  pcsp.update(ins.size(pcsp.pc), diff)
  1659  
  1660          # trace successful
  1661          return maxsp, close
  1662  
  1663      def get(self, key: str) -> Optional[int]:
  1664          if key not in self.labels:
  1665              raise SyntaxError('unresolved reference to name: %s' % key)
  1666          else:
  1667              return self._find_label(key, itertools.repeat(0, len(self.blocks)))
  1668  
  1669      def has(self, key: str) -> bool:
  1670          return key in self.labels
  1671  
  1672      def emit(self, buf: bytes, comments: str = ''):
  1673          if not self.dead:
  1674              self.block.body.append(RawInstr(len(buf), Instruction.encode(buf, comments or buf.hex()), buf))
  1675  
  1676      def lazy(self, size: int, func: Callable[[], int], comments: str = ''):
  1677          if not self.dead:
  1678              self.block.body.append(IntInstr(size, func, comments))
  1679  
  1680      def label(self, name: str):
  1681          if name not in self.labels or self.labels[name].weak:
  1682              self._decl(name, self._make(name))
  1683          else:
  1684              raise SyntaxError('duplicated label: ' + name)
  1685  
  1686      def instr(self, instr: Instruction):
  1687          if not self.dead:
  1688              if self._check_align(instr):
  1689                  return
  1690              self._alloc_instr(instr)
  1691              self._check_split(instr)
  1692  
  1693      # @functools.cache
  1694      def stacksize(self, name: str) -> int:
  1695          if name not in self.labels:
  1696              raise SyntaxError('undefined function: ' + name)
  1697          else:
  1698              return self._trace_block(self.labels[name], None)
  1699  
  1700      # @functools.cache
  1701      def pcsp(self, name: str, entry: int) -> int:
  1702          if name not in self.labels:
  1703              raise SyntaxError('undefined function: ' + name)
  1704          else:
  1705              pcsp = Pcsp(entry)
  1706              self.labels[name].func = True
  1707              return self._trace_block(self.labels[name], pcsp)
  1708  
  1709      def debug(self, pos: int, inss: List[Instruction]):
  1710          def inject(bb: BasicBlock) -> bool:
  1711              if (not bb.func) and (bb.name not in self.funcs):
  1712                  return True
  1713              nonlocal pos
  1714              if pos >= len(bb.body):
  1715                  return
  1716              for ins in inss:
  1717                  bb.body.insert(pos, ins)
  1718                  pos += 1
  1719          visited = {}
  1720          for _, bb in self.labels.items():
  1721              CodeSection._dfs_jump_first(bb, visited, inject)
  1722      def debug(self):
  1723          for label, bb in self.labels.items():
  1724              print(label)
  1725              for v in bb.body:
  1726                  if isinstance(v, (X86Instr, BranchInstr)):
  1727                      print(v.instr.asm_code)
  1728  
  1729  STUB_NAME = '__native_entry__'
  1730  STUB_SIZE = 67
  1731  WITH_OFFS = os.getenv('ASM2ASM_DEBUG_OFFSET', '').lower() in ('1', 'yes', 'true')
  1732  
  1733  class Assembler:
  1734      out  : List[str]
  1735      subr : Dict[str, int]
  1736      code : CodeSection
  1737      vals : Dict[str, Union[str, int]]
  1738  
  1739      def __init__(self):
  1740          self.out  = []
  1741          self.subr = {}
  1742          self.vals = {}
  1743          self.code = CodeSection()
  1744  
  1745      def _get(self, v: str) -> int:
  1746          if v not in self.vals:
  1747              return self.code.get(v)
  1748          elif isinstance(self.vals[v], int):
  1749              return self.vals[v]
  1750          else:
  1751              ret = self.vals[v] = self._eval(self.vals[v])
  1752              return ret
  1753  
  1754      def _eval(self, v: str) -> int:
  1755          return Expression(v).eval(self._get)
  1756  
  1757      def _emit(self, v: bytes, cmd: str):
  1758          align_size = len(v) % 4
  1759          if align_size != 0:
  1760              v += int(0).to_bytes(4 - align_size, 'little')
  1761  
  1762          for i in range(0, len(v), 4):
  1763              self.code.emit(v[i:i + 4], '%s %d, %s' % (cmd, len(v[i:i + 4]), repr(v[i:i + 16])[1:]))
  1764  
  1765      def _limit(self, v: int, a: int, b: int) -> int:
  1766          if not (a <= v <= b):
  1767              raise SyntaxError('integer constant out of bound [%d, %d): %d' % (a, b, v))
  1768          else:
  1769              return v
  1770  
  1771      def _vfill(self, cmd: str, args: List[str]) -> Tuple[int, int]:
  1772          if len(args) == 1:
  1773              return self._limit(self._eval(args[0]), 1, 1 << 64), 0
  1774          elif len(args) == 2:
  1775              return self._limit(self._eval(args[0]), 1, 1 << 64), self._limit(self._eval(args[1]), 0, 255)
  1776          else:
  1777              raise SyntaxError(cmd + ' takes 1 ~ 2 arguments')
  1778  
  1779      def _bytes(self, cmd: str, args: List[str], low: int, high: int, size: int):
  1780          if len(args) != 1:
  1781              raise SyntaxError(cmd + ' takes exact 1 argument')
  1782          else:
  1783              self.code.lazy(size, lambda: self._limit(self._eval(args[0]), low, high) & high, '%s %s' % (cmd, args[0]))
  1784  
  1785      def _comment(self, msg: str):
  1786          self.code.blocks[-1].body.append(CommentInstr(msg))
  1787  
  1788      def _cmd_nop(self, _: List[str]):
  1789          pass
  1790  
  1791      def _cmd_set(self, args: List[str]):
  1792          if len(args) != 2:
  1793              raise SyntaxError(".set takes exact 2 argument")
  1794          elif not args[0].isidentifier():
  1795              raise SyntaxError(repr(args[0]) + " is not a valid identifier")
  1796          else:
  1797              key = args[0]
  1798              val = args[1]
  1799              self.vals[key] = val
  1800              self._comment('.set ' + ', '.join(args))
  1801              # special case: clang-generated jump tables are always like '{block}_{table}'
  1802              jt = val.find(CLANG_JUMPTABLE_LABLE)
  1803              if jt > 0:
  1804                  tab = self.code.get_jmptab(val[jt:])
  1805                  tab.append(self.code.get_block(val[:jt-1]))
  1806  
  1807      def _cmd_byte(self, args: List[str]):
  1808          self._bytes('.byte', args, -0x80, 0xff, 1)
  1809  
  1810      def _cmd_word(self, args: List[str]):
  1811          self._bytes('.word', args, -0x8000, 0xffff, 2)
  1812  
  1813      def _cmd_long(self, args: List[str]):
  1814          self._bytes('.long', args, -0x80000000, 0xffffffff, 4)
  1815  
  1816      def _cmd_quad(self, args: List[str]):
  1817          self._bytes('.quad', args, -0x8000000000000000, 0xffffffffffffffff, 8)
  1818  
  1819      def _cmd_ascii(self, args: List[str]):
  1820          if len(args) != 1:
  1821              raise SyntaxError('.ascii takes exact 1 argument')
  1822          else:
  1823              self._emit(args[0].encode('latin-1'), '.ascii')
  1824  
  1825      def _cmd_asciz(self, args: List[str]):
  1826          if len(args) != 1:
  1827              raise SyntaxError('.asciz takes exact 1 argument')
  1828          else:
  1829              self._emit(args[0].encode('latin-1') + b'\0', '.asciz')
  1830  
  1831      def _cmd_space(self, args: List[str]):
  1832          nb, fv = self._vfill('.space', args)
  1833          self._emit(bytes([fv] * nb), '.space')
  1834  
  1835      def _cmd_p2align(self, args: List[str]):
  1836          if len(args) == 1:
  1837              self.code.block.body.append(AlignmentInstr(self._eval(args[0])))
  1838          elif len(args) == 2:
  1839              self.code.block.body.append(AlignmentInstr(self._eval(args[0]), self._eval(args[1])))
  1840          else:
  1841              raise SyntaxError('.p2align takes 1 ~ 2 arguments')
  1842  
  1843      @functools.cached_property
  1844      def _commands(self) -> dict:
  1845          return {
  1846              '.set'                     : self._cmd_set,
  1847              '.int'                     : self._cmd_long,
  1848              '.long'                    : self._cmd_long,
  1849              '.byte'                    : self._cmd_byte,
  1850              '.quad'                    : self._cmd_quad,
  1851              '.word'                    : self._cmd_word,
  1852              '.hword'                   : self._cmd_word,
  1853              '.short'                   : self._cmd_word,
  1854              '.ascii'                   : self._cmd_ascii,
  1855              '.asciz'                   : self._cmd_asciz,
  1856              '.space'                   : self._cmd_space,
  1857              '.globl'                   : self._cmd_nop,
  1858              '.text'                    : self._cmd_nop,
  1859              '.file'                    : self._cmd_nop,
  1860              '.type'                    : self._cmd_nop,
  1861              '.p2align'                 : self._cmd_p2align,
  1862              '.align'                   : self._cmd_nop,
  1863              '.size'                    : self._cmd_nop,
  1864              '.section'                 : self._cmd_nop,
  1865              '.loh'                     : self._cmd_nop,
  1866              '.data_region'             : self._cmd_nop,
  1867              '.build_version'           : self._cmd_nop,
  1868              '.end_data_region'         : self._cmd_nop,
  1869              '.subsections_via_symbols' : self._cmd_nop,
  1870              # linux-gnu
  1871              '.xword'                   :self._cmd_nop,
  1872          }
  1873  
  1874      @staticmethod
  1875      def _is_rip_relative(op: Operand) -> bool:
  1876          return isinstance(op, Memory) and \
  1877                 op.base is not None    and \
  1878                 op.base.reg == 'rip'   and \
  1879                 op.index is None       and \
  1880                 isinstance(op.disp, Reference)
  1881  
  1882      @staticmethod
  1883      def _remove_comments(line: str, *, st: str = 'normal') -> str:
  1884          for i, ch in enumerate(line):
  1885              if   st == 'normal' and ch == '/'        : st = 'slcomm'
  1886              elif st == 'normal' and ch == '\"'       : st = 'string'
  1887              # elif st == 'normal' and ch in ('#', ';') : return line[:i]
  1888              elif st == 'normal' and ch in (';')      : return line[:i]
  1889              elif st == 'slcomm' and ch == '/'        : return line[:i - 1]
  1890              elif st == 'slcomm'                      : st = 'normal'
  1891              elif st == 'string' and ch == '\"'       : st = 'normal'
  1892              elif st == 'string' and ch == '\\'       : st = 'escape'
  1893              elif st == 'escape'                      : st = 'string'
  1894          else:
  1895              return line
  1896  
  1897      @staticmethod
  1898      def _replace_adrp_line(line: str) -> str:
  1899          if 'adrp' in line:
  1900              line = line.replace('adrp', 'adr').replace('@PAGE', '')
  1901          return line
  1902  
  1903      @staticmethod
  1904      def _replace_adrp(src: List[str]) -> List[str]:
  1905          back_label_count = 0
  1906          adrp_label_map = {}
  1907          new_src = []
  1908          for line in src:
  1909              line = Assembler._remove_comments(line)
  1910              line = line.strip()
  1911  
  1912              if not line:
  1913                  continue
  1914              # is instructions
  1915              if line[-1] != ':' and line[0] != '.':
  1916                  instr = Instruction(line, back_label_count)
  1917                  if instr.ADRP_label:
  1918                      back_label_count += 1
  1919                      new_src.append(instr.asm_code)
  1920                      new_src.append(instr.back_label + ':')
  1921                      if instr.text_label in adrp_label_map:
  1922                          adrp_label_map[instr.text_label] += [(instr.ADRP_label, instr.ADR_instr, instr.back_label)]
  1923                      else:
  1924                          adrp_label_map[instr.text_label] = [(instr.ADRP_label, instr.ADR_instr, instr.back_label)]
  1925                  else:
  1926                      new_src.append(line)
  1927              else:
  1928                  new_src.append(line)
  1929  
  1930          nn_src = []
  1931  
  1932          for line in new_src:
  1933              if line[-1] == ':': # is label
  1934                  if line[:-1] in adrp_label_map:
  1935                      for item in adrp_label_map[line[:-1]]:
  1936                          nn_src.append(item[0] + ':')       # label that adrp will jump to
  1937                          nn_src.append(item[1])             # adr to get really symbol address
  1938                          nn_src.append('b ' + item[2])      # jump back to adrp next instruction
  1939              nn_src.append(line)
  1940  
  1941          return nn_src
  1942  
  1943      def _parse(self, src: List[str]):
  1944          # src = self._replace_adrp(o_src)
  1945  
  1946          for line in src:
  1947              line = Assembler._remove_comments(line)
  1948              line = line.strip()
  1949  
  1950              # skip empty lines
  1951              if not line:
  1952                  continue
  1953  
  1954              # labels, resolve the offset
  1955              if line[-1] == ':':
  1956                  self.code.label(line[:-1])
  1957                  continue
  1958  
  1959              # instructions
  1960              if line[0] != '.':
  1961                  line = self._replace_adrp_line(line)
  1962                  self.code.instr(Instruction(line, 0))
  1963                  continue
  1964  
  1965              # parse the command
  1966              cmd = Command.parse(line)
  1967              func = self._commands.get(cmd.cmd)
  1968  
  1969              # handle the command
  1970              if func is not None:
  1971                  func(cmd.args)
  1972              else:
  1973                  raise SyntaxError('invalid assembly command: ' + cmd.cmd)
  1974  
  1975      def _reloc(self, rip: int = 0):
  1976          for block in self.code.blocks:
  1977              for instr in block.body:
  1978                  rip += self._reloc_one(instr, rip)
  1979  
  1980      def _reloc_one(self, instr: Instr, rip: int) -> int:
  1981          if not isinstance(instr, (X86Instr, BranchInstr)):
  1982              return instr.size(rip)
  1983          elif instr.instr.need_reloc:
  1984              return self._reloc_branch(instr.instr, rip)
  1985          else:
  1986              return instr.resize(self._reloc_normal(instr.instr, rip))
  1987  
  1988      def _reloc_branch(self, instr: Instruction, rip: int) -> int:
  1989          label = instr.label_name
  1990          if label is None:
  1991              raise RuntimeError('cannnot found label name: %s' % instr.asm_code)
  1992          if instr.mnemonic == 'adr' and label == 'Ltmp0':
  1993              instr.set_label_offset(-4)
  1994          else:
  1995              instr.set_label_offset(self.code.get(label)- rip - instr.size)
  1996  
  1997          return instr.size
  1998  
  1999      def _reloc_normal(self, instr: Instruction, rip: int) -> int:
  2000          if instr.need_reloc:
  2001              raise SyntaxError('unresolved instruction when relocation: ' + instr.asm_code)
  2002          return instr.size
  2003  
  2004      def _LE_4bytes_IntIntr_2_RawIntr(self):
  2005          for block in self.code.blocks:
  2006              block.if_all_IntInstr_then_2_RawInstr()
  2007  
  2008      def _declare(self, protos: PrototypeMap):
  2009          if OUTPUT_RAW:
  2010              self._declare_body_raw()
  2011          else:
  2012              name = next(iter(protos))
  2013              self._declare_body(name[1:])
  2014          self._declare_functions(protos)
  2015  
  2016      def _declare_body(self, name: str):
  2017          size = self.code.stacksize(name)
  2018          gosize = 0 if size < 16 else size-16
  2019          self.out.append('TEXT ·_%s_entry__(SB), NOSPLIT, $%d' % (name, gosize))
  2020          self.out.append('\tNO_LOCAL_POINTERS')
  2021          # get current PC
  2022          self.out.append('\tWORD $0x100000a0 // adr x0, .+20')
  2023          # self.out.append('\t'+Instruction('add sp, sp, #%d' % size).encoded)
  2024          self.out.append('\tMOVD R0, ret(FP)')
  2025          self.out.append('\tRET')
  2026          self._LE_4bytes_IntIntr_2_RawIntr()
  2027          self._reloc()
  2028  
  2029          # instruction buffer
  2030          pc = 0
  2031          ins = self.code.instrs
  2032  
  2033          for v in ins:
  2034              self.out.append(('// +%d\n' % pc if WITH_OFFS else '') + v.formatted(pc))
  2035              pc += v.size(pc)
  2036  
  2037      def _declare_body_raw(self):
  2038          self._reloc()
  2039  
  2040          # instruction buffer
  2041          pc = 0
  2042          ins = self.code.instrs
  2043  
  2044          # dump every instruction
  2045          for v in ins:
  2046              self.out.append(v.raw_formatted(pc))
  2047              pc += v.size(pc)
  2048  
  2049      def _declare_function(self, name: str, proto: Prototype):
  2050          offs = 0
  2051          subr = name[1:]
  2052          addr = self.code.get(subr)
  2053          self.subr[subr] = addr
  2054          size = self.code.stacksize(subr)
  2055  
  2056          m_size = size + 64
  2057          # rsp_sub_size = size + 16
  2058  
  2059          if OUTPUT_RAW:
  2060              return
  2061  
  2062          # function header and stack checking
  2063          self.out.append('')
  2064          # frame size is 16 to store x29 and x30
  2065          # self.out.append('TEXT ·%s(SB), NOSPLIT | NOFRAME, $0-%d' % (name, proto.argspace))
  2066          self.out.append('TEXT ·%s(SB), NOSPLIT, $%d-%d' % (name, 0, proto.argspace))
  2067          self.out.append('\tNO_LOCAL_POINTERS')
  2068  
  2069          # add stack check if needed
  2070          if m_size != 0:
  2071              self.out.append('')
  2072              self.out.append('_entry:')
  2073              self.out.append('\tMOVD 16(g), R16')
  2074              if size > 0:
  2075               if size < (0x1 << 12) - 1:
  2076                   self.out.append('\tSUB $%d, RSP, R17' % (m_size))
  2077               elif size < (0x1 << 16) - 1:
  2078                   self.out.append('\tMOVD $%d, R17' % (m_size))
  2079                   self.out.append('\tSUB R17, RSP, R17')
  2080               else:
  2081                   raise RuntimeError('too large stack size: %d' % (m_size))
  2082               self.out.append('\tCMP  R16, R17')
  2083              else:
  2084               self.out.append('\tCMP R16, RSP')
  2085              self.out.append('\tBLS  _stack_grow')
  2086  
  2087          # function name
  2088          self.out.append('')
  2089          self.out.append('%s:' % subr)
  2090  
  2091          # self.out.append('\tMOVD.W R30, -16(RSP)')
  2092          # self.out.append('\tMOVD R29, -8(RSP)')
  2093          # self.out.append('\tSUB $8, RSP, R29')
  2094  
  2095          # intialize all the arguments
  2096          for arg in proto.args:
  2097              offs += arg.size
  2098              op, reg = REG_MAP[arg.creg.reg]
  2099              self.out.append('\t%s %s+%d(FP), %s' % (op, arg.name, offs - arg.size, reg))
  2100  
  2101  
  2102          # Go ASM completely ignores the offset of the JMP instruction,
  2103          # so we need to use indirect jumps instead for tail-call elimination
  2104          
  2105          # LEA and JUMP
  2106          self.out.append('\tMOVD ·_subr_%s(SB), R11' % (subr))
  2107          self.out.append('\tWORD $0x1000005e // adr x30, .+8')
  2108          self.out.append('\tJMP (R11)')
  2109          # self.out.append('\tCALL ·_%s_entry__(SB)  // %s' % (subr, subr))
  2110          
  2111          # normal functions, call the real function, and return the result
  2112          if proto.retv is not None:
  2113              self.out.append('\t%s, %s+%d(FP)' % (' '.join(REG_MAP[proto.retv.creg.reg]), proto.retv.name, offs))
  2114          # Restore LR and Frame Pointer
  2115          # self.out.append('\tLDP -8(RSP), (R29, R30)')
  2116          # self.out.append('\tADD $16, RSP')
  2117          
  2118          self.out.append('\tRET')
  2119  
  2120          # add stack growing if needed
  2121          if m_size != 0:
  2122              self.out.append('')
  2123              self.out.append('_stack_grow:')
  2124              self.out.append('\tMOVD R30, R3')
  2125              self.out.append('\tCALL runtime·morestack_noctxt<>(SB)')
  2126              self.out.append('\tJMP  _entry')
  2127  
  2128      def _declare_functions(self, protos: PrototypeMap):
  2129          for name, proto in sorted(protos.items()):
  2130              if name[0] == '_':
  2131                  self._declare_function(name, proto)
  2132              else:
  2133                  raise SyntaxError('function prototype must have a "_" prefix: ' + repr(name))
  2134  
  2135      def parse(self, src: List[str], proto: PrototypeMap):
  2136          # self.code.instr(Instruction('adr x0, .'))
  2137          # self.code.instr(Instruction('add sp, sp, #%d'%self.code.stacksize(name)))
  2138          # self.code.instr(Instruction('ret'))
  2139          # cmd = Command.parse(".p2align 4")
  2140          # func = self._commands.get(cmd.cmd)
  2141          # func(cmd.args)
  2142  
  2143          self._parse(src)
  2144          self._declare(proto)
  2145  
  2146  GOOS = {
  2147      'aix',
  2148      'android',
  2149      'darwin',
  2150      'dragonfly',
  2151      'freebsd',
  2152      'hurd',
  2153      'illumos',
  2154      'js',
  2155      'linux',
  2156      'nacl',
  2157      'netbsd',
  2158      'openbsd',
  2159      'plan9',
  2160      'solaris',
  2161      'windows',
  2162      'zos',
  2163  }
  2164  
  2165  GOARCH = {
  2166      '386',
  2167      'amd64',
  2168      'amd64p32',
  2169      'arm',
  2170      'armbe',
  2171      'arm64',
  2172      'arm64be',
  2173      'ppc64',
  2174      'ppc64le',
  2175      'mips',
  2176      'mipsle',
  2177      'mips64',
  2178      'mips64le',
  2179      'mips64p32',
  2180      'mips64p32le',
  2181      'ppc',
  2182      'riscv',
  2183      'riscv64',
  2184      's390',
  2185      's390x',
  2186      'sparc',
  2187      'sparc64',
  2188      'wasm',
  2189  }
  2190  
  2191  def make_subr_filename(name: str) -> str:
  2192      name = os.path.basename(name)
  2193      base = os.path.splitext(name)[0].rsplit('_', 2)
  2194  
  2195      # construct the new name
  2196      if base[-1] in GOOS:
  2197          return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1])
  2198      elif base[-1] not in GOARCH:
  2199          return '%s_subr.go' % '_'.join(base)
  2200      elif len(base) > 2 and base[-2] in GOOS:
  2201          return '%s_subr_%s_%s.go' % ('_'.join(base[:-2]), base[-2], base[-1])
  2202      else:
  2203          return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1])
  2204  
  2205  def parse_args():
  2206      parser = argparse.ArgumentParser(description='Convert llvm asm to golang asm.')
  2207      parser.add_argument('proto_file', type=str, help = 'The go file that declares go functions')
  2208      parser.add_argument('asm_file', type=str, nargs='+', help = 'The llvm assembly file')
  2209      parser.add_argument('-r', default=False, action='store_true', help = 'Ture: output as raw; default is False')
  2210      return parser.parse_args()
  2211  
  2212  def main():
  2213      src = []
  2214      args = parse_args()
  2215  
  2216      # check if optional flag is enabled
  2217      global OUTPUT_RAW
  2218      OUTPUT_RAW = False
  2219      if args.r:
  2220          OUTPUT_RAW = True
  2221  
  2222      proto_name = os.path.splitext(args.proto_file)[0]
  2223  
  2224      # parse the prototype
  2225      with open(proto_name + '.go', 'r', newline = None) as fp:
  2226          pkg, proto = PrototypeMap.parse(fp.read())
  2227  
  2228      # read all the sources, and combine them together
  2229      for fn in args.asm_file:
  2230          with open(fn, 'r', newline = None) as fp:
  2231              src.extend(fp.read().splitlines())
  2232  
  2233      asm = Assembler()
  2234  
  2235      # convert the original sources
  2236      if OUTPUT_RAW:
  2237          asm.out.append('// +build arm64')
  2238          asm.out.append('// Code generated by asm2asm, DO NOT EDIT.')
  2239          asm.out.append('')
  2240          asm.out.append('package %s' % pkg)
  2241          asm.out.append('')
  2242          ## native text
  2243          asm.out.append('var Text%s = []byte{' % STUB_NAME)
  2244      else:
  2245          asm.out.append('// +build !noasm !appengine')
  2246          asm.out.append('// Code generated by asm2asm, DO NOT EDIT.')
  2247          asm.out.append('')
  2248          asm.out.append('#include "go_asm.h"')
  2249          asm.out.append('#include "funcdata.h"')
  2250          asm.out.append('#include "textflag.h"')
  2251          asm.out.append('')
  2252  
  2253      asm.parse(src, proto)
  2254  
  2255      if OUTPUT_RAW:
  2256          asrc = proto_name[:proto_name.rfind('_')] + '_text_arm.go'
  2257      else:
  2258          asrc = proto_name + '.s'
  2259  
  2260      # save the converted result
  2261      with open(asrc, 'w')  as fp:
  2262          for line in asm.out:
  2263              print(line, file = fp)
  2264          if OUTPUT_RAW:
  2265              print('}', file = fp)
  2266  
  2267      # calculate the subroutine stub file name
  2268      subr = make_subr_filename(args.proto_file)
  2269      subr = os.path.join(os.path.dirname(args.proto_file), subr)
  2270  
  2271      # save the compiled code stub
  2272      with open(subr, 'w') as fp:
  2273          print('// +build !noasm !appengine', file = fp)
  2274          print('// Code generated by asm2asm, DO NOT EDIT.', file = fp)
  2275          print(file = fp)
  2276          print('package %s' % pkg, file = fp)
  2277  
  2278          # also save the actual function addresses if any
  2279          if not asm.subr:
  2280              return
  2281  
  2282          if OUTPUT_RAW:
  2283              print(file = fp)
  2284              print('import (\n\t`github.com/bytedance/sonic/loader`\n)', file = fp)
  2285  
  2286              # dump every entry for all functions
  2287              print(file = fp)
  2288              print('const (', file = fp)
  2289              for name in asm.code.funcs.keys():
  2290                  addr = asm.code.get(name)
  2291                  if addr is not None:
  2292                      print(f'    _entry_{name} = %d' % addr, file = fp)
  2293              print(')', file = fp)
  2294  
  2295              # dump max stack depth for all functions
  2296              print(file = fp)
  2297              print('const (', file = fp)
  2298              for name in asm.code.funcs.keys():
  2299                  print('    _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp)
  2300              print(')', file = fp)
  2301  
  2302              # dump every text size for all functions
  2303              print(file = fp)
  2304              print('const (', file = fp)
  2305              for name, pcsp in asm.code.funcs.items():
  2306                  if pcsp is not None:
  2307                      # print(f'before {name} optimize {pcsp}')
  2308                      pcsp.optimize()
  2309                      # print(f'after {name} optimize {pcsp}')
  2310                      print(f'    _size_{name} = %d' % (pcsp.maxpc - pcsp.entry), file = fp)
  2311              print(')', file = fp)
  2312  
  2313              # dump every pcsp for all functions
  2314              print(file = fp)
  2315              print('var (', file = fp)
  2316              for name, pcsp in asm.code.funcs.items():
  2317                  if pcsp is not None:
  2318                      print(f'    _pcsp_{name} = %s' % pcsp, file = fp)
  2319              print(')', file = fp)
  2320  
  2321              # insert native entry info
  2322              print(file = fp)
  2323              print('var Funcs = []loader.CFunc{', file = fp)
  2324              print('    {"%s", 0, %d, 0, nil},' % (STUB_NAME, STUB_SIZE), file = fp)
  2325              # dump every native function info for all functions
  2326              for name in asm.code.funcs.keys():
  2327                  print('    {"%s", _entry_%s, _size_%s, _stack_%s, _pcsp_%s},' % (name, name, name, name, name), file = fp)
  2328              print('}', file = fp)
  2329  
  2330          else:
  2331              # native entry for entry function
  2332              print(file = fp)
  2333              print('//go:nosplit', file = fp)
  2334              print('//go:noescape', file = fp)
  2335              print('//goland:noinspection ALL', file = fp)
  2336              for name, entry in asm.subr.items():
  2337                  print('func _%s_entry__() uintptr' % name, file = fp)
  2338              
  2339              # dump exported function entry for exported functions
  2340              print(file = fp)
  2341              print('var (', file = fp)
  2342              mlen = max(len(s) for s in asm.subr)
  2343              for name, entry in asm.subr.items():
  2344                  print('    _subr_%s uintptr = _%s_entry__() + %d' % (name.ljust(mlen, ' '), name, entry), file = fp)
  2345                  # print('    _subr_%s uintptr = %d' % (name.ljust(mlen, ' '), entry), file = fp)
  2346              print(')', file = fp)
  2347  
  2348              # dump max stack depth for exported functions
  2349              print(file = fp)
  2350              print('const (', file = fp)
  2351              for name in asm.subr.keys():
  2352                  print('    _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp)
  2353              print(')', file = fp)
  2354  
  2355              # assign subroutine offsets to '_' to mute the "unused" warnings
  2356              print(file = fp)
  2357              print('var (', file = fp)
  2358              for name in asm.subr:
  2359                  print('    _ = _subr_%s' % name, file = fp)
  2360              print(')', file = fp)
  2361  
  2362              # dump every constant
  2363              print(file = fp)
  2364              print('const (', file = fp)
  2365              for name in asm.subr:
  2366                  print('    _ = _stack_%s' % name, file = fp)
  2367              else:
  2368                  print(')', file = fp)
  2369  
  2370  if __name__ == '__main__':
  2371      main()