github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/tools/asm2asm/asm2asm.py (about)

     1  #!/usr/bin/env python3
     2  # -*- coding: utf-8 -*-
     3  
     4  import os
     5  import sys
     6  import string
     7  import itertools
     8  import functools
     9  
    10  from typing import Any
    11  from typing import Dict
    12  from typing import List
    13  from typing import Type
    14  from typing import Tuple
    15  from typing import Union
    16  from typing import Callable
    17  from typing import Iterable
    18  from typing import Optional
    19  
    20  from peachpy import x86_64
    21  from peachpy.x86_64 import generic
    22  from peachpy.x86_64 import XMMRegister
    23  from peachpy.x86_64.operand import is_rel32
    24  from peachpy.x86_64.operand import MemoryOperand
    25  from peachpy.x86_64.operand import MemoryAddress
    26  from peachpy.x86_64.operand import RIPRelativeOffset
    27  from peachpy.x86_64.instructions import Instruction as PInstr
    28  from peachpy.x86_64.instructions import BranchInstruction
    29  
    30  ### Instruction Parser (GAS Syntax) ###
    31  
    32  class Label:
    33      name: str
    34      offs: Optional[int]
    35  
    36      def __init__(self, name: str):
    37          self.name = name
    38          self.offs = None
    39  
    40      def __str__(self):
    41          return self.name
    42  
    43      def __repr__(self):
    44          if self.offs is None:
    45              return '{LABEL %s (unresolved)}' % self.name
    46          else:
    47              return '{LABEL %s (offset: %d)}' % (self.name, self.offs)
    48  
    49      def resolve(self, offs: int):
    50          self.offs = offs
    51  
    52  class Index:
    53      base  : 'Register'
    54      scale : int
    55  
    56      def __init__(self, base: 'Register', scale: int = 1):
    57          self.base  = base
    58          self.scale = scale
    59  
    60      def __str__(self):
    61          if self.scale == 1:
    62              return ',%s' % self.base
    63          elif self.scale >= 2:
    64              return ',%s,%d' % (self.base, self.scale)
    65          else:
    66              raise RuntimeError('invalid parser state: invalid scale')
    67  
    68      def __repr__(self):
    69          if self.scale == 1:
    70              return repr(self.base)
    71          elif self.scale >= 2:
    72              return '%d * %r' % (self.scale, self.base)
    73          else:
    74              raise RuntimeError('invalid parser state: invalid scale')
    75  
    76  class Memory:
    77      base  : Optional['Register']
    78      disp  : Optional['Displacement']
    79      index : Optional[Index]
    80  
    81      def __init__(self, base: Optional['Register'], disp: Optional['Displacement'], index: Optional[Index]):
    82          self.base  = base
    83          self.disp  = disp
    84          self.index = index
    85          self._validate()
    86  
    87      def __str__(self):
    88          return '%s(%s%s)' % (
    89              '' if self.disp  is None else self.disp,
    90              '' if self.base  is None else self.base,
    91              '' if self.index is None else self.index
    92          )
    93  
    94      def __repr__(self):
    95          return '{MEM %r%s%s}' % (
    96              '' if self.base  is None else self.base,
    97              '' if self.index is None else ' + ' + repr(self.index),
    98              '' if self.disp  is None else ' + ' + repr(self.disp)
    99          )
   100  
   101      def _validate(self):
   102          if self.base is None and self.index is None:
   103              raise SyntaxError('either base or index must be specified')
   104  
   105  class Register:
   106      reg: str
   107  
   108      def __init__(self, reg: str):
   109          self.reg = reg.lower()
   110  
   111      def __str__(self):
   112          return '%' + self.reg
   113  
   114      def __repr__(self):
   115          return '{REG %s}' % self.reg
   116  
   117      @functools.cached_property
   118      def native(self) -> x86_64.registers.Register:
   119          if self.reg == 'rip':
   120              raise SyntaxError('%rip is not directly accessible')
   121          else:
   122              return getattr(x86_64.registers, self.reg)
   123  
   124  class Immediate:
   125      val: int
   126      ref: str
   127  
   128      def __init__(self, val: int):
   129          self.ref = ''
   130          self.val = val
   131  
   132      def __str__(self):
   133          return '$%d' % self.val
   134  
   135      def __repr__(self):
   136          return '{IMM bin:%s, oct:%s, dec:%d, hex:%s}' % (
   137              bin(self.val)[2:],
   138              oct(self.val)[2:],
   139              self.val,
   140              hex(self.val)[2:],
   141          )
   142  
   143  class Reference:
   144      ref: str
   145      disp: int
   146      off: Optional[int]
   147  
   148      def __init__(self, ref: str, disp: int = 0):
   149          self.ref = ref
   150          self.disp = disp
   151          self.off = None
   152  
   153      def __str__(self):
   154          if self.off is None:
   155              return self.ref
   156          else:
   157              return '$' + str(self.off)
   158  
   159      def __repr__(self):
   160          if self.off is None:
   161              return '{REF %s + %d (unresolved)}' % (self.ref, self.disp)
   162          else:
   163              return '{REF %s + %d (offset: %d)}' % (self.ref, self.disp, self.off)
   164  
   165      @property
   166      def offset(self) -> int:
   167          if self.off is None:
   168              raise SyntaxError('unresolved reference to ' + repr(self.ref))
   169          else:
   170              return self.off
   171  
   172      def resolve(self, off: int):
   173          self.off = self.disp + off
   174  
   175  Operand = Union[
   176      Label,
   177      Memory,
   178      Register,
   179      Immediate,
   180      Reference,
   181  ]
   182  
   183  Displacement = Union[
   184      Immediate,
   185      Reference,
   186  ]
   187  
   188  TOKEN_END  = 0
   189  TOKEN_REG  = 1
   190  TOKEN_IMM  = 2
   191  TOKEN_NUM  = 3
   192  TOKEN_NAME = 4
   193  TOKEN_PUNC = 5
   194  
   195  REGISTERS = {
   196      'rax'   , 'eax'   , 'ax'    , 'al'    , 'ah'   ,
   197      'rbx'   , 'ebx'   , 'bx'    , 'bl'    , 'bh'   ,
   198      'rcx'   , 'ecx'   , 'cx'    , 'cl'    , 'ch'   ,
   199      'rdx'   , 'edx'   , 'dx'    , 'dl'    , 'dh'   ,
   200      'rsi'   , 'esi'   , 'si'    , 'sil'   ,
   201      'rdi'   , 'edi'   , 'di'    , 'dil'   ,
   202      'rbp'   , 'ebp'   , 'bp'    , 'bpl'   ,
   203      'rsp'   , 'esp'   , 'sp'    , 'spl'   ,
   204      'r8'    , 'r8d'   , 'r8w'   , 'r8b'   ,
   205      'r9'    , 'r9d'   , 'r9w'   , 'r9b'   ,
   206      'r10'   , 'r10d'  , 'r10w'  , 'r10b'  ,
   207      'r11'   , 'r11d'  , 'r11w'  , 'r11b'  ,
   208      'r12'   , 'r12d'  , 'r12w'  , 'r12b'  ,
   209      'r13'   , 'r13d'  , 'r13w'  , 'r13b'  ,
   210      'r14'   , 'r14d'  , 'r14w'  , 'r14b'  ,
   211      'r15'   , 'r15d'  , 'r15w'  , 'r15b'  ,
   212      'mm0'   , 'mm1'   , 'mm2'   , 'mm3'   , 'mm4'   , 'mm5'   , 'mm6'   , 'mm7'   ,
   213      'xmm0'  , 'xmm1'  , 'xmm2'  , 'xmm3'  , 'xmm4'  , 'xmm5'  , 'xmm6'  , 'xmm7'  ,
   214      'xmm8'  , 'xmm9'  , 'xmm10' , 'xmm11' , 'xmm12' , 'xmm13' , 'xmm14' , 'xmm15' ,
   215      'xmm16' , 'xmm17' , 'xmm18' , 'xmm19' , 'xmm20' , 'xmm21' , 'xmm22' , 'xmm23' ,
   216      'xmm24' , 'xmm25' , 'xmm26' , 'xmm27' , 'xmm28' , 'xmm29' , 'xmm30' , 'xmm31' ,
   217      'ymm0'  , 'ymm1'  , 'ymm2'  , 'ymm3'  , 'ymm4'  , 'ymm5'  , 'ymm6'  , 'ymm7'  ,
   218      'ymm8'  , 'ymm9'  , 'ymm10' , 'ymm11' , 'ymm12' , 'ymm13' , 'ymm14' , 'ymm15' ,
   219      'ymm16' , 'ymm17' , 'ymm18' , 'ymm19' , 'ymm20' , 'ymm21' , 'ymm22' , 'ymm23' ,
   220      'ymm24' , 'ymm25' , 'ymm26' , 'ymm27' , 'ymm28' , 'ymm29' , 'ymm30' , 'ymm31' ,
   221      'zmm0'  , 'zmm1'  , 'zmm2'  , 'zmm3'  , 'zmm4'  , 'zmm5'  , 'zmm6'  , 'zmm7'  ,
   222      'zmm8'  , 'zmm9'  , 'zmm10' , 'zmm11' , 'zmm12' , 'zmm13' , 'zmm14' , 'zmm15' ,
   223      'zmm16' , 'zmm17' , 'zmm18' , 'zmm19' , 'zmm20' , 'zmm21' , 'zmm22' , 'zmm23' ,
   224      'zmm24' , 'zmm25' , 'zmm26' , 'zmm27' , 'zmm28' , 'zmm29' , 'zmm30' , 'zmm31' ,
   225      'rip'   ,
   226  }
   227  
   228  class Token:
   229      tag: int
   230      val: Union[int, str]
   231  
   232      def __init__(self, tag: int, val: Union[int, str]):
   233          self.val = val
   234          self.tag = tag
   235  
   236      @classmethod
   237      def end(cls):
   238          return cls(TOKEN_END, '')
   239  
   240      @classmethod
   241      def reg(cls, reg: str):
   242          return cls(TOKEN_REG, reg)
   243  
   244      @classmethod
   245      def imm(cls, imm: int):
   246          return cls(TOKEN_IMM, imm)
   247  
   248      @classmethod
   249      def num(cls, num: int):
   250          return cls(TOKEN_NUM, num)
   251  
   252      @classmethod
   253      def name(cls, name: str):
   254          return cls(TOKEN_NAME, name)
   255  
   256      @classmethod
   257      def punc(cls, punc: str):
   258          return cls(TOKEN_PUNC, punc)
   259  
   260      def __repr__(self):
   261          if self.tag == TOKEN_END:
   262              return '<END>'
   263          elif self.tag == TOKEN_REG:
   264              return '<REG %s>' % self.val
   265          elif self.tag == TOKEN_IMM:
   266              return '<IMM %d>' % self.val
   267          elif self.tag == TOKEN_NUM:
   268              return '<NUM %d>' % self.val
   269          elif self.tag == TOKEN_NAME:
   270              return '<NAME %s>' % repr(self.val)
   271          elif self.tag == TOKEN_PUNC:
   272              return '<PUNC %s>' % repr(self.val)
   273          else:
   274              return '<UNK:%d %r>' % (self.tag, self.val)
   275  
   276  class Tokenizer:
   277      pos: int
   278      src: str
   279  
   280      def __init__(self, src: str):
   281          self.pos = 0
   282          self.src = src
   283  
   284      @property
   285      def _ch(self) -> str:
   286          return self.src[self.pos]
   287  
   288      @property
   289      def _eof(self) -> bool:
   290          return self.pos >= len(self.src)
   291  
   292      def _rch(self) -> str:
   293          ret, self.pos = self.src[self.pos], self.pos + 1
   294          return ret
   295  
   296      def _rid(self, s: str, allow_dot: bool) -> str:
   297          while not self._eof and (self._ch == '_' or self._ch.isalnum() or (allow_dot and self._ch == '.')):
   298              s += self._rch()
   299          else:
   300              return s
   301  
   302      def _reg(self) -> Token:
   303          if self._eof:
   304              raise SyntaxError('unexpected EOF when parsing register names')
   305          else:
   306              return self._regx()
   307  
   308      def _imm(self) -> Token:
   309          if self._eof:
   310              raise SyntaxError('unexpected EOF when parsing immediate values')
   311          else:
   312              return self._immx(self._rch())
   313  
   314      def _regx(self) -> Token:
   315          nch = self._rch()
   316          reg = self._rid(nch, allow_dot = False).lower()
   317  
   318          # check for register names
   319          if reg not in REGISTERS:
   320              raise SyntaxError('invalid register: ' + reg)
   321          else:
   322              return Token.reg(reg)
   323  
   324      def _immv(self, ch: str) -> int:
   325          while not self._eof and self._ch in string.digits:
   326              ch += self._rch()
   327          else:
   328              return int(ch)
   329  
   330      def _immx(self, ch: str) -> Token:
   331          if ch.isdigit():
   332              return Token.imm(self._immv(ch))
   333          elif ch == '-':
   334              return Token.imm(-self._immv(self._rch()))
   335          else:
   336              raise SyntaxError('unexpected character when parsing immediate value: ' + ch)
   337  
   338      def _name(self, ch: str) -> Token:
   339          return Token.name(self._rid(ch, allow_dot = True))
   340  
   341      def _read(self, ch: str) -> Token:
   342          if ch == '%':
   343              return self._reg()
   344          elif ch == '$':
   345              return self._imm()
   346          elif ch == '-':
   347              return Token.num(-self._immv(self._rch()))
   348          elif ch == '+':
   349              return Token.num(self._immv(self._rch()))
   350          elif ch.isdigit():
   351              return Token.num(self._immv(ch))
   352          elif ch.isidentifier():
   353              return self._name(ch)
   354          elif ch in ('(', ')', ',', '*'):
   355              return Token.punc(ch)
   356          else:
   357              raise SyntaxError('invalid character: ' + repr(ch))
   358  
   359      def next(self) -> Token:
   360          while not self._eof and self._ch.isspace():
   361              self.pos += 1
   362          else:
   363              return Token.end() if self._eof else self._read(self._rch())
   364  
   365  class Instruction:
   366      comments: str
   367      mnemonic: str
   368      operands: List[Operand]
   369  
   370      def __init__(self, mnemonic: str, operands: List[Operand]):
   371          self.comments = ''
   372          self.operands = operands
   373          self.mnemonic = mnemonic.lower()
   374  
   375      def __str__(self):
   376          ops = ', '.join(map(str, self.operands))
   377          com = self.comments and '  /* %s */' % self.comments
   378  
   379          # ordinal instructions
   380          if not self.is_branch:
   381              return '%-12s %s%s' % (self.mnemonic, ops, com)
   382          elif len(self.operands) != 1:
   383              raise SyntaxError('invalid branch instruction: ' + self.mnemonic)
   384          elif isinstance(self.operands[0], Label):
   385              return '%-12s %s%s' % (self.mnemonic, ops, com)
   386          else:
   387              return '%-12s *%s%s' % (self.mnemonic, ops, com)
   388  
   389      def __repr__(self):
   390          return '{INSTR %s: %s%s}' % (
   391              self.mnemonic,
   392              ', '.join(map(repr, self.operands)),
   393              self.comments and ' (%s)' % self.comments
   394          )
   395  
   396      class Basic:
   397          @staticmethod
   398          def INT3(*args, **kwargs):
   399              return x86_64.INT(3, *args, **kwargs)
   400  
   401          @staticmethod
   402          def MOVQ(*args, **kwargs):
   403              if not any(isinstance(v, XMMRegister) for v in args):
   404                  return x86_64.MOV(*args, **kwargs)
   405              else:
   406                  return x86_64.MOVQ(*args, **kwargs)
   407  
   408      class BitShift:
   409          op: Type[PInstr]
   410  
   411          def __init__(self, op: Type[PInstr]):
   412              self.op = op
   413  
   414          def __call__(self, *args, **kwargs):
   415              if len(args) != 1:
   416                  return self.op(*args, **kwargs)
   417              else:
   418                  return self.op(*args, 1, **kwargs)
   419  
   420      class VectorCompare:
   421          fn: int
   422          op: Type[PInstr]
   423  
   424          def __init__(self, op: Type[PInstr], fn: int):
   425              self.fn = fn
   426              self.op = op
   427  
   428          def __call__(self, *args, **kwargs):
   429              return self.op(*args, self.fn, **kwargs)
   430  
   431      __instr_map__ = {
   432          'INT3'       : Basic.INT3,
   433          'SALB'       : BitShift(x86_64.SAL),
   434          'SALW'       : BitShift(x86_64.SAL),
   435          'SALL'       : BitShift(x86_64.SAL),
   436          'SALQ'       : BitShift(x86_64.SAL),
   437          'SARB'       : BitShift(x86_64.SAR),
   438          'SARW'       : BitShift(x86_64.SAR),
   439          'SARL'       : BitShift(x86_64.SAR),
   440          'SARQ'       : BitShift(x86_64.SAR),
   441          'SHLB'       : BitShift(x86_64.SHL),
   442          'SHLW'       : BitShift(x86_64.SHL),
   443          'SHLL'       : BitShift(x86_64.SHL),
   444          'SHLQ'       : BitShift(x86_64.SHL),
   445          'SHRB'       : BitShift(x86_64.SHR),
   446          'SHRW'       : BitShift(x86_64.SHR),
   447          'SHRL'       : BitShift(x86_64.SHR),
   448          'SHRQ'       : BitShift(x86_64.SHR),
   449          'MOVQ'       : Basic.MOVQ,
   450          'CBTW'       : x86_64.CBW,
   451          'CWTL'       : x86_64.CWDE,
   452          'CLTQ'       : x86_64.CDQE,
   453          'MOVZBW'     : x86_64.MOVZX,
   454          'MOVZBL'     : x86_64.MOVZX,
   455          'MOVZWL'     : x86_64.MOVZX,
   456          'MOVZBQ'     : x86_64.MOVZX,
   457          'MOVZWQ'     : x86_64.MOVZX,
   458          'MOVSBW'     : x86_64.MOVSX,
   459          'MOVSBL'     : x86_64.MOVSX,
   460          'MOVSWL'     : x86_64.MOVSX,
   461          'MOVSBQ'     : x86_64.MOVSX,
   462          'MOVSWQ'     : x86_64.MOVSX,
   463          'MOVSLQ'     : x86_64.MOVSXD,
   464          'MOVABSQ'    : x86_64.MOV,
   465          'VCMPEQPS'   : VectorCompare(x86_64.VCMPPS, 0x00),
   466          'VCMPTRUEPS' : VectorCompare(x86_64.VCMPPS, 0x0f),
   467      }
   468  
   469      @functools.cached_property
   470      def _instr(self) -> Union[Type[PInstr], Callable[..., PInstr]]:
   471          name = self.mnemonic.upper()
   472          func = self.__instr_map__.get(name)
   473  
   474          # not found, resolve as x86_64 instruction
   475          if func is None:
   476              func = getattr(x86_64, name, None)
   477  
   478          # try with size suffix removed (only for generic instructions)
   479          if func is None and name[-1] in 'BWLQ':
   480              func = getattr(generic, name[:-1], func)
   481  
   482          # still not found, it should be an error
   483          if func is None:
   484              raise SyntaxError('unknown instruction: ' + self.mnemonic)
   485          else:
   486              return func
   487      
   488      @property
   489      def jmptab(self) -> Optional[str]:
   490          if self.mnemonic == 'leaq' and isinstance(self.operands[0], Memory) and self.operands[0].base.reg == 'rip':
   491              dis = self.operands[0].disp
   492              if dis and dis.ref.find(CLANG_JUMPTABLE_LABLE) != -1:
   493                  return dis.ref
   494  
   495      @property
   496      def _instr_size(self) -> Optional[int]:
   497          ops = self.operands
   498          key = self.mnemonic.upper()
   499  
   500          # special case of sign/zero extension instructions
   501          if key in self.__instr_size__:
   502              return self.__instr_size__[key]
   503  
   504          # check for register operands
   505          for op in ops:
   506              if isinstance(op, Register):
   507                  return None
   508  
   509          # check for size suffix, and this only applies to generic instructions
   510          if key[-1] not in self.__size_map__ or not hasattr(generic, key[:-1]):
   511              raise SyntaxError('ambiguous operand sizes')
   512          else:
   513              return self.__size_map__[key[-1]]
   514  
   515      __size_map__ = {
   516          'B': 1,
   517          'W': 2,
   518          'L': 4,
   519          'Q': 8,
   520      }
   521  
   522      __instr_size__ = {
   523          'MOVZBW' : 1,
   524          'MOVZBL' : 1,
   525          'MOVZWL' : 2,
   526          'MOVZBQ' : 1,
   527          'MOVZWQ' : 2,
   528          'MOVSBW' : 1,
   529          'MOVSBL' : 1,
   530          'MOVSWL' : 2,
   531          'MOVSBQ' : 1,
   532          'MOVSWQ' : 2,
   533          'MOVSLQ' : 4,
   534      }
   535  
   536      @staticmethod
   537      def _encode_r32(ins: PInstr) -> bytes:
   538          ret = [fn(ins.operands) for _, fn in ins.encodings]
   539          ret.sort(key = len)
   540          return ret[-1]
   541  
   542      @classmethod
   543      def _encode_ins(cls, ins: PInstr, force_rel32: bool = False) -> bytes:
   544          if not isinstance(ins, BranchInstruction):
   545              return ins.encode()
   546          elif not is_rel32(ins.operands[0]) or not force_rel32:
   547              return ins.encode()
   548          else:
   549              return cls._encode_r32(ins)
   550  
   551      def _encode_rel(self, rel: Label, sizing: bool, offset: int) -> RIPRelativeOffset:
   552          if rel.offs is not None:
   553              return RIPRelativeOffset(rel.offs)
   554          elif sizing:
   555              return RIPRelativeOffset(offset)
   556          else:
   557              raise SyntaxError('unresolved reference to name: ' + rel.name)
   558  
   559      def _encode_mem(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand:
   560          if mem.base is not None and mem.base.reg == 'rip':
   561              return self._encode_mem_rip(mem, sizing, offset)
   562          else:
   563              return self._encode_mem_std(mem, sizing, offset)
   564  
   565      def _encode_mem_rip(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand:
   566          if mem.disp is None:
   567              return MemoryOperand(RIPRelativeOffset(0))
   568          elif mem.index is not None:
   569              raise SyntaxError('%rip relative addresing does not support indexing')
   570          elif isinstance(mem.disp, Immediate):
   571              return MemoryOperand(RIPRelativeOffset(mem.disp.val))
   572          elif isinstance(mem.disp, Reference):
   573              return MemoryOperand(RIPRelativeOffset(offset if sizing else mem.disp.offset))
   574          else:
   575              raise RuntimeError('illegal memory displacement')
   576  
   577      def _encode_mem_std(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand:
   578          disp  = 0
   579          base  = None
   580          index = None
   581          scale = None
   582  
   583          # add optional base
   584          if mem.base is not None:
   585              base = mem.base.native
   586  
   587          # add optional indexing
   588          if mem.index is not None:
   589              scale = mem.index.scale
   590              index = mem.index.base.native
   591  
   592          # add optional displacement
   593          if mem.disp is not None:
   594              if isinstance(mem.disp, Immediate):
   595                  disp = mem.disp.val
   596              elif isinstance(mem.disp, Reference):
   597                  disp = offset if sizing else mem.disp.offset
   598              else:
   599                  raise RuntimeError('illegal memory displacement')
   600  
   601          # construct the memory address
   602          return MemoryOperand(
   603              size    = self._instr_size,
   604              address = MemoryAddress(base, index, scale, disp),
   605          )
   606  
   607      def _encode_operands(self, sizing: bool, offset: int) -> Iterable[Any]:
   608          for op in self.operands:
   609              if isinstance(op, Label):
   610                  yield self._encode_rel(op, sizing, offset)
   611              elif isinstance(op, Memory):
   612                  yield self._encode_mem(op, sizing, offset)
   613              elif isinstance(op, Register):
   614                  yield op.native
   615              elif isinstance(op, Immediate):
   616                  yield op.val
   617              else:
   618                  raise SyntaxError('cannot encode %s as operand' % repr(op))
   619  
   620      def _encode_branch_rel(self, rel: Label) -> str:
   621          if rel.offs is not None:
   622              return self._encode_normal_instr()
   623          else:
   624              raise RuntimeError('invalid relative branching instruction')
   625          
   626      def _raw_branch_rel(self, rel: Label) -> bytes:
   627          if rel.offs is not None:
   628              return self._raw_normal_instr()
   629          else:
   630              raise RuntimeError('invalid relative branching instruction')
   631  
   632      def _encode_branch_mem(self, mem: Memory) -> str:
   633          raise NotImplementedError('not implemented: memory indirect jump')
   634      
   635      def _raw_branch_mem(self, mem: Memory) -> bytes:
   636          raise NotImplementedError('not implemented: memory indirect jump')
   637  
   638      def _encode_branch_reg(self, reg: Register) -> str:
   639          if reg.reg == 'rip':
   640              raise SyntaxError('%rip cannot be used as a jump target')
   641          elif self.mnemonic != 'jmpq':
   642              raise SyntaxError('invalid indirect jump for instruction: ' + self.mnemonic)
   643          else:
   644              return x86_64.JMP(reg.native).format('go')
   645          
   646      def _raw_branch_reg(self, reg: Register) -> bytes:
   647          if reg.reg == 'rip':
   648              raise SyntaxError('%rip cannot be used as a jump target')
   649          elif self.mnemonic != 'jmpq':
   650              raise SyntaxError('invalid indirect jump for instruction: ' + self.mnemonic)
   651          else:
   652              return x86_64.JMP(reg.native).encode()
   653  
   654      def _encode_branch_instr(self) -> str:
   655          if len(self.operands) != 1:
   656              raise RuntimeError('illegal branch instruction')
   657          elif isinstance(self.operands[0], Label):
   658              return self._encode_branch_rel(self.operands[0])
   659          elif isinstance(self.operands[0], Memory):
   660              return self._encode_branch_mem(self.operands[0])
   661          elif isinstance(self.operands[0], Register):
   662              return self._encode_branch_reg(self.operands[0])
   663          else:
   664              raise RuntimeError('invalid operand type ' + repr(self.operands[0]))
   665          
   666      def _raw_branch_instr(self) -> str:
   667          if len(self.operands) != 1:
   668              raise RuntimeError('illegal branch instruction')
   669          elif isinstance(self.operands[0], Label):
   670              return self._raw_branch_rel(self.operands[0])
   671          elif isinstance(self.operands[0], Memory):
   672              return self._raw_branch_mem(self.operands[0])
   673          elif isinstance(self.operands[0], Register):
   674              return self._raw_branch_reg(self.operands[0])
   675          else:
   676              raise RuntimeError('invalid operand type ' + repr(self.operands[0]))
   677  
   678      def _encode_normal_instr(self) -> str:
   679          ops = self._encode_operands(False, 0)
   680          ret = self._instr(*list(ops)[::-1])
   681  
   682          # encode all instructions as raw bytes
   683          if not self.is_branch_label:
   684              return self.encode(self._encode_ins(ret), str(self))
   685          else:
   686              return self.encode(self._encode_ins(ret, force_rel32 = True), '%s, $%s(%%rip)' % (self, self.operands[0].offs))
   687          
   688      def _raw_normal_instr(self) -> str:
   689          ops = self._encode_operands(False, 0)
   690          ret = self._instr(*list(ops)[::-1])
   691  
   692          # encode all instructions as raw bytes
   693          if not self.is_branch_label:
   694              return self._encode_ins(ret)
   695          else:
   696              return self._encode_ins(ret, force_rel32 = True)
   697  
   698  
   699      @property
   700      def size(self) -> int:
   701          return self.encoded_size(0)
   702  
   703      @functools.cached_property
   704      def encoded(self) -> str:
   705          if self.is_branch:
   706              return self._encode_branch_instr()
   707          else:
   708              return self._encode_normal_instr()
   709          
   710      def raw(self) -> bytes:
   711          if self.is_branch:
   712              return self._raw_branch_instr()
   713          else:
   714              return self._raw_normal_instr()
   715  
   716      @functools.cached_property
   717      def is_return(self) -> bool:
   718          return self._instr is x86_64.RET
   719  
   720      @functools.cached_property
   721      def is_invoke(self) -> bool:
   722          return self._instr is x86_64.CALL
   723  
   724      @functools.cached_property
   725      def is_branch(self) -> bool:
   726          try:
   727              return self.is_invoke or issubclass(self._instr, BranchInstruction)
   728          except TypeError:
   729              return False
   730  
   731      @functools.cached_property
   732      def is_jmp(self) -> bool:
   733          return self._instr is x86_64.JMP
   734      
   735      @functools.cached_property
   736      def is_jmpq(self) -> bool:
   737          return self.mnemonic == 'jmpq'
   738  
   739      @property
   740      def is_branch_label(self) -> bool:
   741          return self.is_branch and isinstance(self.operands[0], Label)
   742  
   743      def encoded_size(self, offset: int) -> int:
   744          op = self._encode_operands(True, offset)
   745          return len(self._encode_ins(self._instr(*list(op)[::-1]), force_rel32 = True))
   746  
   747      @classmethod
   748      def parse(cls, line: str) -> 'Instruction':
   749          lex = Tokenizer(line)
   750          ntk = lex.next()
   751  
   752          # the first token must be a name
   753          if ntk.tag != TOKEN_NAME:
   754              raise SyntaxError('mnemonic expected, got ' + repr(ntk))
   755          else:
   756              return cls(ntk.val, cls._parse_operands(lex))
   757  
   758      @staticmethod
   759      def encode(buf: bytes, comments: str = '') -> str:
   760          i = 0
   761          r = []
   762          n = len(buf)
   763  
   764          # try "QUAD" first
   765          while i < n - 7:
   766              r.append('QUAD $0x%016x' % int.from_bytes(buf[i:i + 8], 'little'))
   767              i += 8
   768  
   769          # then "LONG"
   770          while i < n - 3:
   771              r.append('LONG $0x%08x' % int.from_bytes(buf[i:i + 4], 'little'))
   772              i += 4
   773  
   774          # then "SHORT"
   775          while i < n - 1:
   776              r.append('WORD $0x%04x' % int.from_bytes(buf[i:i + 2], 'little'))
   777              i += 2
   778  
   779          # then "BYTE"
   780          while i < n:
   781              r.append('BYTE $0x%02x' % buf[i])
   782              i += 1
   783  
   784          # join them together, and attach the comment if any
   785          if not comments:
   786              return '; '.join(r)
   787          else:
   788              return '%s  // %s' % ('; '.join(r), comments)
   789  
   790      Reg  = Optional[Register]
   791      Disp = Optional[Displacement]
   792  
   793      @classmethod
   794      def _parse_mend(cls, ntk: Token, base: Reg, index: Register, scale: int, disp: Disp) -> Operand:
   795          if ntk.tag != TOKEN_PUNC or ntk.val != ')':
   796              raise SyntaxError('")" expected, got ' + repr(ntk))
   797          else:
   798              return Memory(base, disp, Index(index, scale))
   799  
   800      @classmethod
   801      def _parse_base(cls, lex: Tokenizer, ntk: Token, disp: Disp) -> Operand:
   802          if ntk.tag == TOKEN_REG:
   803              return cls._parse_idelim(lex, lex.next(), Register(ntk.val), disp)
   804          elif ntk.tag == TOKEN_PUNC and ntk.val == ',':
   805              return cls._parse_ibase(lex, lex.next(), None, disp)
   806          else:
   807              raise SyntaxError('register expected, got ' + repr(ntk))
   808  
   809      @classmethod
   810      def _parse_ibase(cls, lex: Tokenizer, ntk: Token, base: Reg, disp: Disp) -> Operand:
   811          if ntk.tag != TOKEN_REG:
   812              raise SyntaxError('register expected, got ' + repr(ntk))
   813          else:
   814              return cls._parse_sdelim(lex, lex.next(), base, Register(ntk.val), disp)
   815  
   816      @classmethod
   817      def _parse_idelim(cls, lex: Tokenizer, ntk: Token, base: Reg, disp: Disp) -> Operand:
   818          if ntk.tag == TOKEN_END:
   819              raise SyntaxError('unexpected EOF when parsing memory operands')
   820          elif ntk.tag == TOKEN_PUNC and ntk.val == ')':
   821              return Memory(base, disp, None)
   822          elif ntk.tag == TOKEN_PUNC and ntk.val == ',':
   823              return cls._parse_ibase(lex, lex.next(), base, disp)
   824          else:
   825              raise SyntaxError('"," or ")" expected, got ' + repr(ntk))
   826  
   827      @classmethod
   828      def _parse_iscale(cls, lex: Tokenizer, ntk: Token, base: Reg, index: Register, disp: Disp) -> Operand:
   829          if ntk.tag != TOKEN_NUM:
   830              raise SyntaxError('integer expected, got ' + repr(ntk))
   831          elif ntk.val not in (1, 2, 4, 8):
   832              raise SyntaxError('indexing scale can only be 1, 2, 4 or 8')
   833          else:
   834              return cls._parse_mend(lex.next(), base, index, ntk.val, disp)
   835  
   836      @classmethod
   837      def _parse_sdelim(cls, lex: Tokenizer, ntk: Token, base: Reg, index: Register, disp: Disp) -> Operand:
   838          if ntk.tag == TOKEN_END:
   839              raise SyntaxError('unexpected EOF when parsing memory operands')
   840          elif ntk.tag == TOKEN_PUNC and ntk.val == ')':
   841              return Memory(base, disp, Index(index))
   842          elif ntk.tag == TOKEN_PUNC and ntk.val == ',':
   843              return cls._parse_iscale(lex, lex.next(), base, index, disp)
   844          else:
   845              raise SyntaxError('"," or ")" expected, got ' + repr(ntk))
   846  
   847      @classmethod
   848      def _parse_refmem(cls, lex: Tokenizer, ntk: Token, ref: str) -> Operand:
   849          if ntk.tag == TOKEN_END:
   850              return Label(ref)
   851          elif ntk.tag == TOKEN_PUNC and ntk.val == '(' :
   852              return cls._parse_memory(lex, ntk, Reference(ref, 0))
   853          elif ntk.tag == TOKEN_NUM:
   854              ntk = lex.next()
   855              if ntk.tag == TOKEN_PUNC and ntk.val == '(':
   856                  return cls._parse_refmem(lex, ntk, Reference(ref, ntk.val))
   857          
   858          raise SyntaxError(f'identifier "{ref}" must either be a label or a displacement reference')
   859  
   860      @classmethod
   861      def _parse_memory(cls, lex: Tokenizer, ntk: Token, disp: Optional[Displacement]) -> Operand:
   862          if ntk.tag != TOKEN_PUNC or ntk.val != '(':
   863              raise SyntaxError('"(" expected, got ' + repr(ntk))
   864          else:
   865              return cls._parse_base(lex, lex.next(), disp)
   866  
   867      @classmethod
   868      def _parse_operand(cls, lex: Tokenizer, ntk: Token, can_indir: bool = True) -> Operand:
   869          if ntk.tag == TOKEN_REG:
   870              return Register(ntk.val)
   871          elif ntk.tag == TOKEN_IMM:
   872              return Immediate(ntk.val)
   873          elif ntk.tag == TOKEN_NUM:
   874              return cls._parse_memory(lex, lex.next(), Immediate(ntk.val))
   875          elif ntk.tag == TOKEN_NAME:
   876              return cls._parse_refmem(lex, lex.next(), ntk.val)
   877          elif ntk.tag == TOKEN_PUNC and ntk.val == '(':
   878              return cls._parse_memory(lex, ntk, None)
   879          elif ntk.tag == TOKEN_PUNC and ntk.val == '*' and can_indir:
   880              return cls._parse_operand(lex, lex.next(), False)
   881          else:
   882              raise SyntaxError('invalid token: ' + repr(ntk))
   883  
   884      @classmethod
   885      def _parse_operands(cls, lex: Tokenizer) -> List[Operand]:
   886          ret = []
   887          ntk = lex.next()
   888  
   889          # check for empty operand
   890          if ntk.tag == TOKEN_END:
   891              return []
   892  
   893          # parse every operand
   894          while True:
   895              ret.append(cls._parse_operand(lex, ntk))
   896              ntk = lex.next()
   897  
   898              # check for the ',' delimiter or the end of input
   899              if ntk.tag == TOKEN_PUNC and ntk.val == ',':
   900                  ntk = lex.next()
   901              elif ntk.tag != TOKEN_END:
   902                  raise SyntaxError('"," expected, got ' + repr(ntk))
   903              else:
   904                  return ret
   905  
   906  ### Prototype Parser ###
   907  
   908  ARGS_ORDER_C = [
   909      Register('rdi'),
   910      Register('rsi'),
   911      Register('rdx'),
   912      Register('rcx'),
   913      Register('r8'),
   914      Register('r9'),
   915  ]
   916  
   917  ARGS_ORDER_GO = [
   918      Register('rax'),
   919      Register('rbx'),
   920      Register('rcx'),
   921      Register('rdi'),
   922      Register('rsi'),
   923      Register('r8'),
   924  ]
   925  
   926  FPARGS_ORDER = [
   927      Register('xmm0'),
   928      Register('xmm1'),
   929      Register('xmm2'),
   930      Register('xmm3'),
   931      Register('xmm4'),
   932      Register('xmm5'),
   933      Register('xmm6'),
   934      Register('xmm7'),
   935  ]
   936  
   937  class Parameter:
   938      name : str
   939      size : int
   940      creg : Register
   941      goreg: Register
   942  
   943      def __init__(self, name: str, size: int, reg: Register, goreg: Register):
   944          self.creg  = reg
   945          self.goreg = reg
   946          self.name = name
   947          self.size = size
   948  
   949      def __repr__(self):
   950          return '<ARG %s(%d): %s>' % (self.name, self.size, self.creg)
   951  
   952  class Pcsp:
   953      entry: int
   954      maxpc: int
   955      out  : List[Tuple[int, int]]
   956      pc   : int
   957      sp   : int
   958      
   959      def __init__(self, entry: int):
   960          self.out = []
   961          self.maxpc = entry
   962          self.entry = entry
   963          self.pc = entry
   964          self.sp = 0
   965      
   966      def __str__(self) -> str:
   967          ret = '[][2]uint32{\n'
   968          for pc, sp in self.out:
   969              ret += '        {%d, %d},\n' % (pc, sp)
   970          return ret + '    }'
   971      
   972      def optimize(self):
   973          # push the last record
   974          self.out.append((self.pc - self.entry, self.sp))
   975          # sort by pc
   976          self.out.sort(key=lambda x: x[0])
   977          # NOTICE: first pair {1, 0} to be compitable with golang
   978          tmp = [(1, 0)]
   979          lpc, lsp = 0, -1
   980          for pc, sp in self.out:
   981              # sp changed, push new record
   982              if pc != lpc and sp != lsp:
   983                      tmp.append((pc, sp))
   984              # sp unchanged, replace with the higher pc
   985              if pc != lpc and sp == lsp:
   986                  if len(tmp) > 0:
   987                      tmp.pop(-1)
   988                  tmp.append((pc, sp))
   989                  
   990              lpc, lsp = pc, sp
   991          self.out = tmp
   992      
   993      def update(self, dpc: int, dsp: int):
   994          self.out.append((self.pc - self.entry, self.sp))
   995          self.pc += dpc
   996          self.sp += dsp
   997          if self.pc > self.maxpc:
   998              self.maxpc = self.pc
   999  
  1000  class Prototype:
  1001      args: List[Parameter]
  1002      retv: Optional[Parameter]
  1003  
  1004      def __init__(self, retv: Optional[Parameter], args: List[Parameter]):
  1005          self.retv = retv
  1006          self.args = args
  1007  
  1008      def __repr__(self):
  1009          if self.retv is None:
  1010              return '<PROTO (%s)>' % repr(self.args)
  1011          else:
  1012              return '<PROTO (%r) -> %r>' % (self.args, self.retv)
  1013  
  1014      @property
  1015      def argspace(self) -> int:
  1016          return sum(
  1017              [v.size for v in self.args],
  1018              (0 if self.retv is None else self.retv.size)
  1019          )
  1020  
  1021  class PrototypeMap(Dict[str, Prototype]):
  1022      @staticmethod
  1023      def _dv(c: str) -> int:
  1024          if c == '(':
  1025              return 1
  1026          elif c == ')':
  1027              return -1
  1028          else:
  1029              return 0
  1030  
  1031      @staticmethod
  1032      def _tk(s: str, p: str) -> bool:
  1033          return s.startswith(p) and (s == p or s[len(p)].isspace())
  1034      
  1035      @classmethod
  1036      def _punc(cls, s: str) -> bool:
  1037          return s in cls.__puncs_
  1038  
  1039      @staticmethod
  1040      def _err(msg: str) -> SyntaxError:
  1041          return SyntaxError(
  1042              msg + ', ' +
  1043              'the parser integrated in this tool is just a text-based parser, ' +
  1044              'so please keep the companion .go file as simple as possible and do not use defined types'
  1045          )
  1046  
  1047      @staticmethod
  1048      def _align(nb: int) -> int:
  1049          return (((nb - 1) >> 3) + 1) << 3
  1050  
  1051      @classmethod
  1052      def _retv(cls, ret: str) -> Tuple[str, int, Register, Register]:
  1053          name, size, xmm = cls._args(ret)
  1054          reg = Register('xmm0') if xmm else Register('rax')
  1055          return name, size, reg, reg
  1056  
  1057      @classmethod
  1058      def _args(cls, arg: str, sv: str = '') -> Tuple[str, int, bool]:
  1059          while True:
  1060              if not arg:
  1061                  raise SyntaxError('missing type for parameter: ' + sv)
  1062              elif arg[0] != '_' and not arg[0].isalnum():
  1063                  return (sv,) + cls._size(arg.strip())
  1064              elif not sv and arg[0].isdigit():
  1065                  raise SyntaxError('invalid character: ' + repr(arg[0]))
  1066              else:
  1067                  sv += arg[0]
  1068                  arg = arg[1:]
  1069  
  1070      @classmethod
  1071      def _size(cls, name: str) -> Tuple[int, bool]:
  1072          if name[0] == '*':
  1073              return cls._align(8), False
  1074          elif name in ('int8', 'uint8', 'byte', 'bool'):
  1075              return cls._align(1), False
  1076          elif name in ('int16', 'uint16'):
  1077              return cls._align(2), False
  1078          elif name == 'float32':
  1079              return cls._align(4), True
  1080          elif name in ('int32', 'uint32', 'rune'):
  1081              return cls._align(4), False
  1082          elif name == 'float64':
  1083              return cls._align(8), True
  1084          elif name in ('int64', 'uint64', 'uintptr', 'int', 'Pointer', 'unsafe.Pointer'):
  1085              return cls._align(8), False
  1086          else:
  1087              raise cls._err('unrecognized type "%s"' % name)
  1088  
  1089      @classmethod
  1090      def _func(cls, src: List[str], idx: int, depth: int = 0) -> Tuple[str, int]:
  1091          for i in range(idx, len(src)):
  1092              for x in map(cls._dv, src[i]):
  1093                  if depth + x >= 0:
  1094                      depth += x
  1095                  else:
  1096                      raise cls._err('encountered ")" more than "(" on line %d' % (i + 1))
  1097              else:
  1098                  if depth == 0:
  1099                      return ' '.join(src[idx:i + 1]), i + 1
  1100          else:
  1101              raise cls._err('unexpected EOF when parsing function signatures')
  1102  
  1103      @classmethod
  1104      def parse(cls, src: str) -> Tuple[str, 'PrototypeMap']:
  1105          idx = 0
  1106          pkg = ''
  1107          ret = PrototypeMap()
  1108          buf = src.splitlines()
  1109          
  1110          # scan through all the lines
  1111          while idx < len(buf):
  1112              line = buf[idx]
  1113              line = line.strip()
  1114  
  1115              # skip empty lines
  1116              if not line:
  1117                  idx += 1
  1118                  continue
  1119  
  1120              # check for package name
  1121              if cls._tk(line, 'package'):
  1122                  idx, pkg = idx + 1, line[7:].strip().split()[0]
  1123                  continue
  1124  
  1125              if OUTPUT_RAW:
  1126                  
  1127                  # extract funcname like "[var ]{funcname} = func(..."
  1128                  end = line.find('func(')
  1129                  if end == -1:
  1130                      idx += 1
  1131                      continue
  1132                  name = line[:end].strip()
  1133                  if name.startswith('var '):
  1134                      name = name[4:].strip()
  1135                  
  1136                  # function names must be identifiers
  1137                  if not name.isidentifier():
  1138                      raise cls._err('invalid function prototype: ' + name)
  1139                  
  1140                  # register a empty prototype
  1141                  ret[name] = Prototype(None, [])
  1142                  idx += 1
  1143                  
  1144              else:      
  1145                                
  1146                  # only cares about those functions that does not have bodies
  1147                  if line[-1] == '{' or not cls._tk(line, 'func'):
  1148                      idx += 1
  1149                      continue
  1150  
  1151                  # prevent type-aliasing primitive types into other names
  1152                  if cls._tk(line, 'type'):
  1153                      raise cls._err('please do not declare any type with in the companion .go file')
  1154  
  1155                  # find the next function declaration
  1156                  decl, pos = cls._func(buf, idx)
  1157                  func, idx = decl[4:].strip(), pos
  1158  
  1159                  # find the beginning '('
  1160                  nd = 1
  1161                  pos = func.find('(')
  1162  
  1163                  # must have a '('
  1164                  if pos == -1:
  1165                      raise cls._err('invalid function prototype: ' + decl)
  1166  
  1167                  # extract the name and signature
  1168                  args = ''
  1169                  name = func[:pos].strip()
  1170                  func = func[pos + 1:].strip()
  1171  
  1172                  # skip the method declaration
  1173                  if not name:
  1174                      continue
  1175  
  1176                  # function names must be identifiers
  1177                  if not name.isidentifier():
  1178                      raise cls._err('invalid function prototype: ' + decl)
  1179  
  1180                  # extract the argument list
  1181                  while nd and func:
  1182                      nch  = func[0]
  1183                      func = func[1:]
  1184  
  1185                      # adjust the nesting level
  1186                      nd   += cls._dv(nch)
  1187                      args += nch
  1188  
  1189                  # check for EOF
  1190                  if not nd:
  1191                      func = func.strip()
  1192                  else:
  1193                      raise cls._err('unexpected EOF when parsing function prototype: ' + decl)
  1194  
  1195                  # check for multiple returns
  1196                  if ',' in func:
  1197                      raise cls._err('can only return a single value (detected by looking for "," within the return list)')
  1198  
  1199                  # check for return signature
  1200                  if not func:
  1201                      retv = None
  1202                  elif func[0] == '(' and func[-1] == ')':
  1203                      retv = Parameter(*cls._retv(func[1:-1]))
  1204                  else:
  1205                      raise SyntaxError('badly formatted return argument (please use parenthesis and proper arguments naming): ' + func)
  1206  
  1207                  # extract the argument list
  1208                  if not args[:-1]:
  1209                      args, alens, axmm = [], [], []
  1210                  else:
  1211                      args, alens, axmm = list(zip(*[cls._args(v.strip()) for v in args[:-1].split(',')]))
  1212  
  1213                  # check for the result
  1214                  cregs = []
  1215                  goregs = []
  1216                  idxs = [0, 0]
  1217  
  1218                  # split the integer & floating point registers
  1219                  for xmm in axmm:
  1220                      key = 0 if xmm else 1
  1221                      seq = FPARGS_ORDER if xmm else ARGS_ORDER_C
  1222                      goseq = FPARGS_ORDER if xmm else ARGS_ORDER_GO
  1223  
  1224                      # check the argument count
  1225                      if idxs[key] >= len(seq):
  1226                          raise cls._err("too many arguments, consider pack some into a pointer")
  1227  
  1228                      # add the register
  1229                      cregs.append(seq[idxs[key]])
  1230                      goregs.append(goseq[idxs[key]])
  1231                      idxs[key] += 1
  1232  
  1233                  # register the prototype
  1234                  ret[name] = Prototype(retv, [
  1235                      Parameter(arg, size, creg, goreg)
  1236                      for arg, size, creg, goreg in zip(args, alens, cregs, goregs)
  1237                  ])
  1238  
  1239          # all done
  1240          return pkg, ret
  1241  
  1242  ### Assembly Source Parser ###
  1243  
  1244  ESC_IDLE = 0    # escape parser is idleing
  1245  ESC_ISTR = 1    # currently inside a string
  1246  ESC_BKSL = 2    # encountered backslash, prepare for escape sequences
  1247  ESC_HEX0 = 3    # expect the first hexadecimal character of a "\x" escape
  1248  ESC_HEX1 = 4    # expect the second hexadecimal character of a "\x" escape
  1249  ESC_OCT1 = 5    # expect the second octal character of a "\000" escape
  1250  ESC_OCT2 = 6    # expect the third octal character of a "\000" escape
  1251  
  1252  class Command:
  1253      cmd  : str
  1254      args : List[Union[str, bytes]]
  1255  
  1256      def __init__(self, cmd: str, args: List[Union[str, bytes]]):
  1257          self.cmd  = cmd
  1258          self.args = args
  1259  
  1260      def __repr__(self):
  1261          return '<CMD %s %s>' % (self.cmd, ', '.join(map(repr, self.args)))
  1262  
  1263      @classmethod
  1264      def parse(cls, src: str) -> 'Command':
  1265          val = src.split(None, 1)
  1266          cmd = val[0]
  1267  
  1268          # no parameters
  1269          if len(val) == 1:
  1270              return cls(cmd, [])
  1271  
  1272          # extract the argument string
  1273          idx = 0
  1274          esc = 0
  1275          pos = None
  1276          args = []
  1277          vstr = val[1]
  1278  
  1279          # scan through the whole string
  1280          while idx < len(vstr):
  1281              nch = vstr[idx]
  1282              idx += 1
  1283  
  1284              # mark the start of the argument
  1285              if pos is None:
  1286                  pos = idx - 1
  1287  
  1288              # encountered the delimiter outside of a string
  1289              if nch == ',' and esc == ESC_IDLE:
  1290                  pos, p = None, pos
  1291                  args.append(vstr[p:idx - 1].strip())
  1292  
  1293              # start of a string
  1294              elif nch == '"' and esc == ESC_IDLE:
  1295                  esc = ESC_ISTR
  1296  
  1297              # end of string
  1298              elif nch == '"' and esc == ESC_ISTR:
  1299                  esc = ESC_IDLE
  1300                  pos, p = None, pos
  1301                  args.append(vstr[p:idx].strip()[1:-1].encode('utf-8').decode('unicode_escape'))
  1302  
  1303              # escape characters
  1304              elif nch == '\\' and esc == ESC_ISTR:
  1305                  esc = ESC_BKSL
  1306  
  1307              # hexadecimal escape characters (3 chars)
  1308              elif esc == ESC_BKSL and nch == 'x':
  1309                  esc = ESC_HEX0
  1310  
  1311              # octal escape characters (3 chars)
  1312              elif esc == ESC_BKSL and nch in string.octdigits:
  1313                  esc = ESC_OCT1
  1314  
  1315              # generic escape characters (single char)
  1316              elif esc == ESC_BKSL and nch in ('a', 'b', 'f', 'r', 'n', 't', 'v', '"', '\\'):
  1317                  esc = ESC_ISTR
  1318  
  1319              # invalid escape sequence
  1320              elif esc == ESC_BKSL:
  1321                  raise SyntaxError('invalid escape character: ' + repr(nch))
  1322  
  1323              # normal characters, simply advance to the next character
  1324              elif esc in (ESC_IDLE, ESC_ISTR):
  1325                  pass
  1326  
  1327              # hexadecimal escape characters
  1328              elif esc in (ESC_HEX0, ESC_HEX1) and nch.lower() in string.hexdigits:
  1329                  esc = ESC_HEX1 if esc == ESC_HEX0 else ESC_ISTR
  1330  
  1331              # invalid hexadecimal character
  1332              elif esc in (ESC_HEX0, ESC_HEX1):
  1333                  raise SyntaxError('invalid hexdecimal character: ' + repr(nch))
  1334  
  1335              # octal escape characters
  1336              elif esc in (ESC_OCT1, ESC_OCT2) and nch.lower() in string.octdigits:
  1337                  esc = ESC_OCT2 if esc == ESC_OCT1 else ESC_ISTR
  1338  
  1339              # at most 3 octal digits
  1340              elif esc in (ESC_OCT1, ESC_OCT2):
  1341                  esc = ESC_ISTR
  1342  
  1343              # illegal state, should not happen
  1344              else:
  1345                  raise RuntimeError('illegal state: %d' % esc)
  1346  
  1347          # check for the last argument
  1348          if pos is None:
  1349              return cls(cmd, args)
  1350  
  1351          # add the last argument and build the command
  1352          args.append(vstr[pos:].strip())
  1353          return cls(cmd, args)
  1354  
  1355  class Expression:
  1356      pos: int
  1357      src: str
  1358  
  1359      def __init__(self, src: str):
  1360          self.pos = 0
  1361          self.src = src
  1362  
  1363      @property
  1364      def _ch(self) -> str:
  1365          return self.src[self.pos]
  1366  
  1367      @property
  1368      def _eof(self) -> bool:
  1369          return self.pos >= len(self.src)
  1370  
  1371      def _rch(self) -> str:
  1372          pos, self.pos = self.pos, self.pos + 1
  1373          return self.src[pos]
  1374  
  1375      def _hex(self, ch: str) -> bool:
  1376          if len(ch) == 1 and ch[0] == '0':
  1377              return self._ch.lower() == 'x'
  1378          elif len(ch) <= 1 or ch[1].lower() != 'x':
  1379              return self._ch.isdigit()
  1380          else:
  1381              return self._ch in string.hexdigits
  1382  
  1383      def _int(self, ch: str) -> Token:
  1384          while not self._eof and self._hex(ch):
  1385              ch += self._rch()
  1386          else:
  1387              if ch.lower().startswith('0x'):
  1388                  return Token.num(int(ch, 16))
  1389              elif ch[0] == '0':
  1390                  return Token.num(int(ch, 8))
  1391              else:
  1392                  return Token.num(int(ch))
  1393  
  1394      def _name(self, ch: str) -> Token:
  1395          while not self._eof and (self._ch == '_' or self._ch.isalnum()):
  1396              ch += self._rch()
  1397          else:
  1398              return Token.name(ch)
  1399  
  1400      def _read(self, ch: str) -> Token:
  1401          if ch.isdigit():
  1402              return self._int(ch)
  1403          elif ch.isidentifier():
  1404              return self._name(ch)
  1405          elif ch in ('*', '<', '>') and not self._eof and self._ch == ch:
  1406              return Token.punc(self._rch() * 2)
  1407          elif ch in ('+', '-', '*', '/', '%', '&', '|', '^', '~', '(', ')'):
  1408              return Token.punc(ch)
  1409          else:
  1410              raise SyntaxError('invalid character: ' + repr(ch))
  1411  
  1412      def _peek(self) -> Optional[Token]:
  1413          pos = self.pos
  1414          ret = self._next()
  1415          self.pos = pos
  1416          return ret
  1417  
  1418      def _next(self) -> Optional[Token]:
  1419          while not self._eof and self._ch.isspace():
  1420              self.pos += 1
  1421          else:
  1422              return Token.end() if self._eof else self._read(self._rch())
  1423  
  1424      def _grab(self, tk: Token, getvalue: Callable[[str], int]) -> int:
  1425          if tk.tag == TOKEN_NUM:
  1426              return tk.val
  1427          elif tk.tag == TOKEN_NAME:
  1428              return getvalue(tk.val)
  1429          else:
  1430              raise SyntaxError('integer or identifier expected, got ' + repr(tk))
  1431  
  1432      __pred__ = [
  1433          {'<<', '>>'},
  1434          {'|'},
  1435          {'^'},
  1436          {'&'},
  1437          {'+', '-'},
  1438          {'*', '/', '%'},
  1439          {'**'},
  1440      ]
  1441  
  1442      __binary__ = {
  1443          '+'  : lambda a, b: a + b,
  1444          '-'  : lambda a, b: a - b,
  1445          '*'  : lambda a, b: a * b,
  1446          '/'  : lambda a, b: a / b,
  1447          '%'  : lambda a, b: a % b,
  1448          '&'  : lambda a, b: a & b,
  1449          '^'  : lambda a, b: a ^ b,
  1450          '|'  : lambda a, b: a | b,
  1451          '<<' : lambda a, b: a << b,
  1452          '>>' : lambda a, b: a >> b,
  1453          '**' : lambda a, b: a ** b,
  1454      }
  1455  
  1456      def _eval(self, op: str, v1: int, v2: int) -> int:
  1457          return self.__binary__[op](v1, v2)
  1458  
  1459      def _nest(self, nest: int, getvalue: Callable[[str], int]) -> int:
  1460          ret = self._expr(0, nest + 1, getvalue)
  1461          ntk = self._next()
  1462  
  1463          # it must follows with a ')' operator
  1464          if ntk.tag != TOKEN_PUNC or ntk.val != ')':
  1465              raise SyntaxError('")" expected, got ' + repr(ntk))
  1466          else:
  1467              return ret
  1468  
  1469      def _unit(self, nest: int, getvalue: Callable[[str], int]) -> int:
  1470          tk = self._next()
  1471          tt, tv = tk.tag, tk.val
  1472  
  1473          # check for unary operators
  1474          if tt == TOKEN_NUM:
  1475              return tv
  1476          elif tt == TOKEN_NAME:
  1477              return getvalue(tv)
  1478          elif tt == TOKEN_PUNC and tv == '(':
  1479              return self._nest(nest, getvalue)
  1480          elif tt == TOKEN_PUNC and tv == '+':
  1481              return self._unit(nest, getvalue)
  1482          elif tt == TOKEN_PUNC and tv == '-':
  1483              return -self._unit(nest, getvalue)
  1484          elif tt == TOKEN_PUNC and tv == '~':
  1485              return ~self._unit(nest, getvalue)
  1486          else:
  1487              raise SyntaxError('integer, unary operator or nested expression expected, got ' + repr(tk))
  1488  
  1489      def _term(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
  1490          lv = self._expr(pred + 1, nest, getvalue)
  1491          tk = self._peek()
  1492  
  1493          # scan to the end
  1494          while True:
  1495              tt = tk.tag
  1496              tv = tk.val
  1497  
  1498              # encountered EOF
  1499              if tt == TOKEN_END:
  1500                  return lv
  1501  
  1502              # must be an operator here
  1503              if tt != TOKEN_PUNC:
  1504                  raise SyntaxError('operator expected, got ' + repr(tk))
  1505  
  1506              # check for the operator precedence
  1507              if tv not in self.__pred__[pred]:
  1508                  return lv
  1509  
  1510              # apply the operator
  1511              op = self._next().val
  1512              rv = self._expr(pred + 1, nest, getvalue)
  1513              lv = self._eval(op, lv, rv)
  1514              tk = self._peek()
  1515  
  1516      def _expr(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
  1517          if pred >= len(self.__pred__):
  1518              return self._unit(nest, getvalue)
  1519          else:
  1520              return self._term(pred, nest, getvalue)
  1521  
  1522      def eval(self, getvalue: Callable[[str], int]) -> int:
  1523          return self._expr(0, 0, getvalue)
  1524  
  1525  
  1526  class Instr:
  1527      ALIGN_WIDTH = 48
  1528      len   : int                     = NotImplemented
  1529      instr : Union[str, Instruction] = NotImplemented
  1530  
  1531      def size(self, pc: int) -> int:
  1532          return self.len
  1533  
  1534      def formatted(self, pc: int) -> str:
  1535          raise NotImplementedError
  1536      
  1537      @staticmethod
  1538      def raw_formatted(bs: bytes, comm: str, pc: int) -> str:
  1539          t = '\t'
  1540          if bs:
  1541              for b in bs:
  1542                  t +='0x%02x, ' % b
  1543              # if len(bs)<Instr.ALIGN_WIDTH:
  1544              #     t += '\b' * (Instr.ALIGN_WIDTH - len(bs))
  1545          return '%s//%s%s' % (t, ('0x%08x ' % pc) if pc else ' ', comm)
  1546  class RawInstr(Instr):
  1547      bs: bytes
  1548      def __init__(self, size: int, instr: str, bs: bytes):
  1549          self.len = size
  1550          self.instr = instr
  1551          self.bs = bs
  1552  
  1553      def formatted(self, _: int) -> str:
  1554          return '\t' + self.instr
  1555      
  1556      def raw_formatted(self, pc: int) -> str:
  1557          return Instr.raw_formatted(self.bs, self.instr, pc)
  1558          
  1559  class IntInstr(Instr):
  1560      comm: str
  1561      func: Callable[[], int]
  1562  
  1563      def __init__(self, size: int, func: Callable[[], int], comments: str = ''):
  1564          self.len = size
  1565          self.func = func
  1566          self.comm = comments
  1567  
  1568      @property
  1569      def instr(self) -> str:
  1570          return Instruction.encode(self.func().to_bytes(self.len, 'little'), self.comm)
  1571  
  1572      def formatted(self, _: int) -> str:
  1573          return '\t' + self.instr
  1574      
  1575      def raw_formatted(self, pc: int) -> str:
  1576          return Instr.raw_formatted(self.func().to_bytes(self.len, 'little'), self.comm, pc)
  1577  
  1578  class X86Instr(Instr):
  1579      def __init__(self, instr: Instruction):
  1580          self.len = instr.size
  1581          self.instr = instr
  1582  
  1583      def resize(self, size: int) -> int:
  1584          self.len = size
  1585          return size
  1586  
  1587      def formatted(self, _: int) -> str:
  1588          return '\t' + str(self.instr.encoded)
  1589      
  1590      def raw_formatted(self, pc: int) -> str:
  1591          return Instr.raw_formatted(self.instr._raw_normal_instr(), str(self.instr), pc)
  1592  
  1593  class LabelInstr(Instr):
  1594      def __init__(self, name: str):
  1595          self.len = 0
  1596          self.instr = name
  1597  
  1598      def formatted(self, _: int) -> str:
  1599          if self.instr.isidentifier():
  1600              return self.instr + ':'
  1601          else:
  1602              return '_LB_%08x: // %s' % (hash(self.instr) & 0xffffffff, self.instr)
  1603          
  1604      def raw_formatted(self, pc: int) -> str:
  1605          return Instr.raw_formatted(None, str(self.instr), pc)
  1606  
  1607  class BranchInstr(Instr):
  1608      def __init__(self, instr: Instruction):
  1609          self.len = instr.size
  1610          self.instr = instr
  1611  
  1612      def formatted(self, _: int) -> str:
  1613          return '\t' + self.instr.encoded
  1614  
  1615      def raw_formatted(self, pc: int) -> str:
  1616          return Instr.raw_formatted(self.instr._raw_branch_instr(), str(self.instr), pc)
  1617      
  1618  class CommentInstr(Instr):
  1619      def __init__(self, text: str):
  1620          self.len = 0
  1621          self.instr = '// ' + text
  1622  
  1623      def formatted(self, _: int) -> str:
  1624          return '\t' + self.instr
  1625  
  1626      def raw_formatted(self, pc: int) -> str:
  1627          return  Instr.raw_formatted(None, str(self.instr), None)
  1628      
  1629  class AlignmentInstr(Instr):
  1630      bits: int
  1631      fill: int
  1632  
  1633      def __init__(self, bits: int, fill: int = 0):
  1634          self.bits = bits
  1635          self.fill = fill
  1636  
  1637      def size(self, pc: int) -> int:
  1638          mask = (1 << self.bits) - 1
  1639          return (mask - (pc & mask) + 1) & mask
  1640  
  1641      def formatted(self, pc: int) -> str:
  1642          buf = bytes([self.fill]) * self.size(pc)
  1643          return '\t' + Instruction.encode(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill))
  1644      
  1645      def raw_formatted(self, pc: int) -> str:
  1646          buf = bytes([self.fill]) * self.size(pc)
  1647          return Instr.raw_formatted(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill), pc)
  1648  
  1649  REG_MAP = {
  1650      'rax'  : ('MOVQ'  , 'AX'),
  1651      'rdi'  : ('MOVQ'  , 'DI'),
  1652      'rsi'  : ('MOVQ'  , 'SI'),
  1653      'rdx'  : ('MOVQ'  , 'DX'),
  1654      'rcx'  : ('MOVQ'  , 'CX'),
  1655      'r8'   : ('MOVQ'  , 'R8'),
  1656      'r9'   : ('MOVQ'  , 'R9'),
  1657      'xmm0' : ('MOVSD' , 'X0'),
  1658      'xmm1' : ('MOVSD' , 'X1'),
  1659      'xmm2' : ('MOVSD' , 'X2'),
  1660      'xmm3' : ('MOVSD' , 'X3'),
  1661      'xmm4' : ('MOVSD' , 'X4'),
  1662      'xmm5' : ('MOVSD' , 'X5'),
  1663      'xmm6' : ('MOVSD' , 'X6'),
  1664      'xmm7' : ('MOVSD' , 'X7'),
  1665  }
  1666  
  1667  class Counter:
  1668      value: int = 0
  1669  
  1670      @classmethod
  1671      def next(cls) -> int:
  1672          val, cls.value = cls.value, cls.value + 1
  1673          return val
  1674  
  1675  class BasicBlock:
  1676      maxsp: int
  1677      name: str
  1678      weak: bool
  1679      jmptab: bool
  1680      func: bool
  1681      body: List[Instr]
  1682      prevs: List['BasicBlock']
  1683      next: Optional['BasicBlock']
  1684      jump: Optional['BasicBlock']
  1685  
  1686      def __init__(self, name: str, weak: bool = True, jmptab: bool = False, func: bool = False):
  1687          self.maxsp = -1
  1688          self.body = []
  1689          self.prevs = []
  1690          self.name = name
  1691          self.weak = weak
  1692          self.next = None
  1693          self.jump = None
  1694          self.jmptab = jmptab
  1695          self.func = func
  1696              
  1697      def __repr__(self):
  1698          return '{BasicBlock %s}' % repr(self.name)
  1699  
  1700      @property
  1701      def last(self) -> Optional[Instr]:
  1702          return next((v for v in reversed(self.body) if not isinstance(v, CommentInstr)), None)
  1703  
  1704      def size_of(self, pc: int) -> int:
  1705          return functools.reduce(lambda p, v: p + v.size(pc + p), self.body, 0)
  1706  
  1707      def link_to(self, block: 'BasicBlock'):
  1708          self.next = block
  1709          block.prevs.append(self)
  1710  
  1711      def jump_to(self, block: 'BasicBlock'):
  1712          self.jump = block
  1713          block.prevs.append(self)
  1714  
  1715      @classmethod
  1716      def annonymous(cls) -> 'BasicBlock':
  1717          return cls('// bb.%d' % Counter.next(), weak = False)
  1718  
  1719  CLANG_JUMPTABLE_LABLE = 'LJTI'
  1720  
  1721  class CodeSection:
  1722      dead   : bool
  1723      export : bool
  1724      blocks : List[BasicBlock]
  1725      labels : Dict[str, BasicBlock]
  1726      jmptabs: Dict[str, List[BasicBlock]]
  1727      funcs  : Dict[str, Pcsp]
  1728  
  1729      def __init__(self):
  1730          self.dead   = False
  1731          self.labels = {}
  1732          self.export = False
  1733          self.blocks = [BasicBlock.annonymous()]
  1734          self.jmptabs = {}
  1735          self.funcs = {}
  1736      
  1737      @classmethod
  1738      def _dfs_jump_first(cls, bb: BasicBlock, visited: Dict[BasicBlock, bool], hook: Callable[[BasicBlock], bool]) -> bool:
  1739          if bb not in visited or not visited[bb]:
  1740              visited[bb] = True
  1741              if bb.jump and not cls._dfs_jump_first(bb.jump, visited, hook):
  1742                  return False
  1743              if bb.next and not cls._dfs_jump_first(bb.next, visited, hook):
  1744                  return False
  1745              return hook(bb)
  1746          else:
  1747              return True
  1748                  
  1749      def get_jmptab(self, name: str) -> List[BasicBlock]:
  1750          return self.jmptabs.setdefault(name, [])
  1751      
  1752      def get_block(self, name: str) -> BasicBlock:
  1753          for block in self.blocks:
  1754              if block.name == name:
  1755                  return block
  1756  
  1757      @property
  1758      def block(self) -> BasicBlock:
  1759          return self.blocks[-1]
  1760  
  1761      @property
  1762      def instrs(self) -> Iterable[Instr]:
  1763          for block in self.blocks:
  1764              yield from block.body
  1765  
  1766      def _make(self, name: str, jmptab: bool = False, func: bool = False):    
  1767          if func:
  1768          #NOTICE: if it is a function, always set func to be True
  1769              if (old := self.labels.get(name)) and (old.func != func):
  1770                  old.func = True
  1771          return self.labels.setdefault(name, BasicBlock(name, jmptab = jmptab, func = func))
  1772      
  1773      def _next(self, link: BasicBlock):
  1774          if self.dead:
  1775              self.dead = False
  1776          else:
  1777              self.block.link_to(link)
  1778  
  1779      def _decl(self, name: str, block: BasicBlock):
  1780          block.weak = False
  1781          block.body.append(LabelInstr(name))
  1782          self._next(block)
  1783          self.blocks.append(block)
  1784  
  1785      def _kill(self, name: str):
  1786          self.dead = True
  1787          self.block.link_to(self._make(name))
  1788  
  1789      def _split(self, jmp: BasicBlock):
  1790          self.jump = True
  1791          link = BasicBlock.annonymous()
  1792          self.labels[link.name] = link
  1793          self.block.link_to(link)
  1794          self.block.jump_to(jmp)
  1795          self.blocks.append(link)
  1796  
  1797      @staticmethod
  1798      def _mk_align(v: int) -> int:
  1799          if v & 7 == 0:
  1800              return v
  1801          else:
  1802              print('* warning: SP is not aligned with 8 bytes.', file = sys.stderr)
  1803              return (v + 7) & -8
  1804  
  1805      @staticmethod
  1806      def _is_spadj(ins: Instruction) -> bool:
  1807          return len(ins.operands) == 2                 and \
  1808                 isinstance(ins.operands[0], Immediate) and \
  1809                 isinstance(ins.operands[1], Register)  and \
  1810                 ins.operands[1].reg == 'rsp'
  1811  
  1812      @staticmethod
  1813      def _is_spmove(ins: Instruction, i: int) -> bool:
  1814          return len(ins.operands) == 2                and \
  1815                 isinstance(ins.operands[0], Register) and \
  1816                 isinstance(ins.operands[1], Register) and \
  1817                 ins.operands[i].reg == 'rsp'
  1818  
  1819      @staticmethod
  1820      def _is_rjump(ins: Optional[Instr]) -> bool:
  1821          return isinstance(ins, X86Instr) and ins.instr.is_branch_label
  1822  
  1823      def _find_label(self, name: str, adjs: Iterable[int], size: int = 0) -> int:
  1824          for adj, block in zip(adjs, self.blocks):
  1825              if block.name == name:
  1826                  return size
  1827              else:
  1828                  size += block.size_of(size) + adj
  1829          else:
  1830              raise SyntaxError('unresolved reference to name: ' + name)
  1831  
  1832      def _alloc_instr(self, instr: Instruction):
  1833          if not instr.is_branch_label:
  1834              self.block.body.append(X86Instr(instr))
  1835          else:
  1836              self.block.body.append(BranchInstr(instr))
  1837  
  1838      # it seems to not be able to specify stack aligment inside the Go ASM so we
  1839      # need to replace the aligned instructions with unaligned one if either of it's
  1840      # operand is an RBP relative addressing memory operand
  1841  
  1842      __instr_repl__ = {
  1843          'movdqa'  : 'movdqu',
  1844          'movaps'  : 'movups',
  1845          'vmovdqa' : 'vmovdqu',
  1846          'vmovaps' : 'vmovups',
  1847          'vmovapd' : 'vmovupd',
  1848      }
  1849  
  1850      def _check_align(self, instr: Instruction) -> bool:
  1851          if instr.mnemonic in self.__instr_repl__:
  1852              # NOTICE: since we need use unaligned instruction, thus SP can be fixed according to PC
  1853              for op in instr.operands:
  1854                  if isinstance(op, Memory):
  1855                      if op.base is not None and (op.base.reg == 'rbp' or op.base.reg == 'rsp'):
  1856                          instr.mnemonic = self.__instr_repl__[instr.mnemonic]
  1857                          return False
  1858          elif instr.mnemonic == 'andq' and self._is_spadj(instr):
  1859              # NOTICE: since we always use unaligned instruction above, we don't need align SP
  1860              return True
  1861  
  1862      def _check_split(self, instr: Instruction):
  1863          if instr.is_return:
  1864              self.dead = True
  1865              
  1866          elif instr.is_jmpq: # jmpq
  1867              # backtrace jump table from current block (BFS)
  1868              prevs = [self.block]
  1869              visited = set()
  1870              while len(prevs) > 0:
  1871                  curb = prevs.pop()
  1872                  if curb in visited:
  1873                      continue
  1874                  else:
  1875                      visited.add(curb)
  1876                      
  1877                  # backtrace instructions
  1878                  for ins in reversed(curb.body):
  1879                      if isinstance(ins, X86Instr) and ins.instr.jmptab:
  1880                          self._split(self._make(ins.instr.jmptab, jmptab = True))
  1881                          return
  1882                      
  1883                  if curb.prevs:
  1884                      prevs.extend(curb.prevs)
  1885                      
  1886          elif instr.is_branch_label:
  1887              if instr.is_jmp: # jmp
  1888                  self._kill(instr.operands[0].name)
  1889                  
  1890              elif instr.is_invoke: # call
  1891                  fname = instr.operands[0].name
  1892                  self._split(self._make(fname, func = True))
  1893                  
  1894              else: # jeq, ja, jae ...
  1895                  self._split(self._make(instr.operands[0].name)) 
  1896  
  1897      def _trace_block(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int:
  1898          if (pcsp is not None):
  1899              if bb.name in self.funcs:
  1900                  # already traced
  1901                  pcsp = None
  1902              else:
  1903                  # continue tracing, update the pcsp
  1904                  # NOTICE: must mark pcsp at block entry because go only calculate delta value
  1905                  pcsp.pc = self.get(bb.name)
  1906                  if bb.func or pcsp.pc < pcsp.entry:  
  1907                      # new func
  1908                      pcsp = Pcsp(pcsp.pc)
  1909                      self.funcs[bb.name] = pcsp
  1910              
  1911          if bb.maxsp == -1:
  1912              ret = self._trace_nocache(bb, pcsp)
  1913              return ret
  1914          elif bb.maxsp >= 0:
  1915              return bb.maxsp
  1916          else:
  1917              return 0
  1918  
  1919      def _trace_nocache(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int:
  1920          bb.maxsp = -2
  1921          
  1922          # ## FIXME:
  1923          # if pcsp is None:
  1924          #     pcsp = Pcsp(0)
  1925          
  1926          # make a fake object just for reducing redundant checking
  1927          if pcsp:
  1928              pc0, sp0 = pcsp.pc, pcsp.sp
  1929              
  1930          maxsp, term = self._trace_instructions(bb, pcsp)
  1931  
  1932          # this is a terminating block
  1933          if term:
  1934              return maxsp
  1935  
  1936          # don't trace it's next block if it's an unconditional jump
  1937          a, b = 0, 0
  1938          if pcsp:
  1939              pc, sp = pcsp.pc, pcsp.sp
  1940          
  1941          if bb.jump:
  1942              if bb.jump.jmptab:
  1943                  cases = self.get_jmptab(bb.jump.name)                    
  1944                  for case in cases:
  1945                      nsp = self._trace_block(case, pcsp)
  1946                      if pcsp:
  1947                          pcsp.pc, pcsp.sp = pc, sp
  1948                      if nsp > a:
  1949                          a = nsp
  1950              else:
  1951                  a = self._trace_block(bb.jump, pcsp)
  1952                  if pcsp:
  1953                      pcsp.pc, pcsp.sp = pc, sp
  1954              
  1955          if bb.next: 
  1956              b = self._trace_block(bb.next, pcsp)
  1957          
  1958          if pcsp:
  1959              pcsp.pc, pcsp.sp = pc0, sp0
  1960              
  1961          # select the maximum stack depth
  1962          bb.maxsp = maxsp + max(a, b)
  1963          return bb.maxsp
  1964  
  1965      def _trace_instructions(self, bb: BasicBlock, pcsp: Pcsp) -> Tuple[int, bool]:
  1966          cursp = 0
  1967          maxsp = 0
  1968          close = False
  1969  
  1970          # scan every instruction
  1971          for ins in bb.body:
  1972              diff = 0
  1973              
  1974              if isinstance(ins, X86Instr):
  1975                  name = ins.instr.mnemonic
  1976                  args = ins.instr.operands
  1977  
  1978                  # check for instructions
  1979                  if name == 'retq':
  1980                      close = True
  1981                  elif name == 'popq':
  1982                      diff = -8
  1983                  elif name == 'pushq':
  1984                      diff = 8
  1985                  elif name == 'addq' and self._is_spadj(ins.instr):
  1986                      diff = -self._mk_align(args[0].val)
  1987                  elif name == 'subq' and self._is_spadj(ins.instr):
  1988                      diff = self._mk_align(args[0].val)
  1989                      
  1990                  # FIXME: andq is usually used for aligment of memory address, we can't handle it correctly now
  1991                  # elif name == 'andq' and self._is_spadj(ins.instr): 
  1992                  #     diff = self._mk_align(max(-args[0].val - 8, 0))
  1993                  
  1994                  cursp += diff
  1995                  
  1996                  #NOTICE: pcsp no need to update here
  1997                  if name == 'callq':
  1998                      cursp += 8
  1999                          
  2000                  # update the max stack depth
  2001                  if cursp > maxsp:
  2002                      maxsp = cursp
  2003              
  2004              # update pcsp   
  2005              if pcsp:
  2006                  pcsp.update(ins.size(pcsp.pc), diff)
  2007  
  2008          # trace successful
  2009          return maxsp, close
  2010  
  2011      def get(self, key: str) -> Optional[int]:
  2012          if key not in self.labels:
  2013              raise SyntaxError('unresolved reference to name: ' + key)
  2014          else:
  2015              return self._find_label(key, itertools.repeat(0, len(self.blocks)))
  2016  
  2017      def has(self, key: str) -> bool:
  2018          return key in self.labels
  2019  
  2020      def emit(self, buf: bytes, comments: str = ''):
  2021          if not self.dead:
  2022              self.block.body.append(RawInstr(len(buf), Instruction.encode(buf, comments or buf.hex()), buf))
  2023  
  2024      def lazy(self, size: int, func: Callable[[], int], comments: str = ''):
  2025          if not self.dead:
  2026              self.block.body.append(IntInstr(size, func, comments))
  2027  
  2028      def label(self, name: str):
  2029          if name not in self.labels or self.labels[name].weak:
  2030              self._decl(name, self._make(name))
  2031          else:
  2032              raise SyntaxError('duplicated label: ' + name)
  2033  
  2034      def instr(self, instr: Instruction):
  2035          if not self.dead:
  2036              if self._check_align(instr):
  2037                  return
  2038              self._alloc_instr(instr)
  2039              self._check_split(instr)
  2040  
  2041      def stacksize(self, name: str) -> int:
  2042          if name not in self.labels:
  2043              raise SyntaxError('undefined function: ' + name)
  2044          else:
  2045              return self._trace_block(self.labels[name], None)
  2046          
  2047      def pcsp(self, name: str, entry: int) -> int:
  2048          if name not in self.labels:
  2049              raise SyntaxError('undefined function: ' + name)
  2050          else:
  2051              pcsp = Pcsp(entry)
  2052              self.labels[name].func = True
  2053              return self._trace_block(self.labels[name], pcsp)
  2054          
  2055      def debug(self, pos: int, inss: List[Instruction]):
  2056          def inject(bb: BasicBlock) -> bool:
  2057              if (not bb.func) and (bb.name not in self.funcs):
  2058                  return True
  2059              nonlocal pos
  2060              if pos >= len(bb.body):
  2061                  return
  2062              for ins in inss:
  2063                  bb.body.insert(pos, ins)  
  2064                  pos += 1
  2065          visited = {}
  2066          for _, bb in self.labels.items():
  2067              CodeSection._dfs_jump_first(bb, visited, inject)
  2068  
  2069  STUB_NAME = '__native_entry__'
  2070  STUB_SIZE = 67
  2071  WITH_OFFS = os.getenv('ASM2ASM_DEBUG_OFFSET', '').lower() in ('1', 'yes', 'true')
  2072  
  2073  class Assembler:
  2074      out  : List[str]
  2075      subr : Dict[str, int]
  2076      code : CodeSection
  2077      vals : Dict[str, Union[str, int]]
  2078  
  2079      def __init__(self):
  2080          self.out  = []
  2081          self.subr = {}
  2082          self.vals = {}
  2083          self.code = CodeSection()
  2084  
  2085      def _get(self, v: str) -> int:
  2086          if v not in self.vals:
  2087              return self.code.get(v)
  2088          elif isinstance(self.vals[v], int):
  2089              return self.vals[v]
  2090          else:
  2091              ret = self.vals[v] = self._eval(self.vals[v])
  2092              return ret
  2093  
  2094      def _eval(self, v: str) -> int:
  2095          return Expression(v).eval(self._get)
  2096  
  2097      def _emit(self, v: bytes, cmd: str):
  2098          for i in range(0, len(v), 16):
  2099              self.code.emit(v[i:i + 16], '%s %d, %s' % (cmd, len(v[i:i + 16]), repr(v[i:i + 16])[1:]))
  2100  
  2101      def _limit(self, v: int, a: int, b: int) -> int:
  2102          if not (a <= v <= b):
  2103              raise SyntaxError('integer constant out of bound [%d, %d): %d' % (a, b, v))
  2104          else:
  2105              return v
  2106  
  2107      def _vfill(self, cmd: str, args: List[str]) -> Tuple[int, int]:
  2108          if len(args) == 1:
  2109              return self._limit(self._eval(args[0]), 1, 1 << 64), 0
  2110          elif len(args) == 2:
  2111              return self._limit(self._eval(args[0]), 1, 1 << 64), self._limit(self._eval(args[1]), 0, 255)
  2112          else:
  2113              raise SyntaxError(cmd + ' takes 1 ~ 2 arguments')
  2114  
  2115      def _bytes(self, cmd: str, args: List[str], low: int, high: int, size: int):
  2116          if len(args) != 1:
  2117              raise SyntaxError(cmd + ' takes exact 1 argument')
  2118          else:
  2119              self.code.lazy(size, lambda: self._limit(self._eval(args[0]), low, high) & high, '%s %s' % (cmd, args[0]))
  2120  
  2121      def _comment(self, msg: str):
  2122          self.code.blocks[-1].body.append(CommentInstr(msg))
  2123  
  2124      def _cmd_nop(self, _: List[str]):
  2125          pass
  2126  
  2127      def _cmd_set(self, args: List[str]):
  2128          if len(args) != 2:
  2129              raise SyntaxError(".set takes exact 2 argument")
  2130          elif not args[0].isidentifier():
  2131              raise SyntaxError(repr(args[0]) + " is not a valid identifier")
  2132          else:
  2133              key = args[0]
  2134              val = args[1]
  2135              self.vals[key] = val
  2136              self._comment('.set ' + ', '.join(args))
  2137              # special case: clang-generated jump tables are always like '{block}_{table}'
  2138              jt = val.find(CLANG_JUMPTABLE_LABLE)
  2139              if jt > 0:
  2140                  tab = self.code.get_jmptab(val[jt:])
  2141                  tab.append(self.code.get_block(val[:jt-1]))
  2142  
  2143      def _cmd_byte(self, args: List[str]):
  2144          self._bytes('.byte', args, -0x80, 0xff, 1)
  2145  
  2146      def _cmd_word(self, args: List[str]):
  2147          self._bytes('.word', args, -0x8000, 0xffff, 2)
  2148  
  2149      def _cmd_long(self, args: List[str]):
  2150          self._bytes('.long', args, -0x80000000, 0xffffffff, 4)
  2151  
  2152      def _cmd_quad(self, args: List[str]):
  2153          self._bytes('.quad', args, -0x8000000000000000, 0xffffffffffffffff, 8)
  2154  
  2155      def _cmd_ascii(self, args: List[str]):
  2156          if len(args) != 1:
  2157              raise SyntaxError('.ascii takes exact 1 argument')
  2158          else:
  2159              self._emit(args[0].encode('latin-1'), '.ascii')
  2160  
  2161      def _cmd_asciz(self, args: List[str]):
  2162          if len(args) != 1:
  2163              raise SyntaxError('.asciz takes exact 1 argument')
  2164          else:
  2165              self._emit(args[0].encode('latin-1') + b'\0', '.asciz')
  2166  
  2167      def _cmd_space(self, args: List[str]):
  2168          nb, fv = self._vfill('.space', args)
  2169          self._emit(bytes([fv] * nb), '.space')
  2170  
  2171      def _cmd_p2align(self, args: List[str]):
  2172          if len(args) == 1:
  2173              self.code.block.body.append(AlignmentInstr(self._eval(args[0])))
  2174          elif len(args) == 2:
  2175              self.code.block.body.append(AlignmentInstr(self._eval(args[0]), self._eval(args[1])))
  2176          else:
  2177              raise SyntaxError('.p2align takes 1 ~ 2 arguments')
  2178  
  2179      @functools.cached_property
  2180      def _commands(self) -> dict:
  2181          return {
  2182              '.set'                     : self._cmd_set,
  2183              '.int'                     : self._cmd_long,
  2184              '.long'                    : self._cmd_long,
  2185              '.byte'                    : self._cmd_byte,
  2186              '.quad'                    : self._cmd_quad,
  2187              '.word'                    : self._cmd_word,
  2188              '.hword'                   : self._cmd_word,
  2189              '.short'                   : self._cmd_word,
  2190              '.ascii'                   : self._cmd_ascii,
  2191              '.asciz'                   : self._cmd_asciz,
  2192              '.space'                   : self._cmd_space,
  2193              '.globl'                   : self._cmd_nop,
  2194              '.p2align'                 : self._cmd_p2align,
  2195              '.section'                 : self._cmd_nop,
  2196              '.data_region'             : self._cmd_nop,
  2197              '.build_version'           : self._cmd_nop,
  2198              '.end_data_region'         : self._cmd_nop,
  2199              '.subsections_via_symbols' : self._cmd_nop,
  2200          }
  2201  
  2202      @staticmethod
  2203      def _is_rip_relative(op: Operand) -> bool:
  2204          return isinstance(op, Memory) and \
  2205                 op.base is not None    and \
  2206                 op.base.reg == 'rip'   and \
  2207                 op.index is None       and \
  2208                 isinstance(op.disp, Reference)
  2209  
  2210      @staticmethod
  2211      def _remove_comments(line: str, *, st: str = 'normal') -> str:
  2212          for i, ch in enumerate(line):
  2213              if   st == 'normal' and ch == '/'        : st = 'slcomm'
  2214              elif st == 'normal' and ch == '\"'       : st = 'string'
  2215              elif st == 'normal' and ch in ('#', ';') : return line[:i]
  2216              elif st == 'slcomm' and ch == '/'        : return line[:i - 1]
  2217              elif st == 'slcomm'                      : st = 'normal'
  2218              elif st == 'string' and ch == '\"'       : st = 'normal'
  2219              elif st == 'string' and ch == '\\'       : st = 'escape'
  2220              elif st == 'escape'                      : st = 'string'
  2221          else:
  2222              return line
  2223  
  2224      def _parse(self, src: List[str]):
  2225          for line in src:
  2226              line = self._remove_comments(line)
  2227              line = line.strip()
  2228  
  2229              # skip empty lines
  2230              if not line:
  2231                  continue
  2232  
  2233              # labels, resolve the offset
  2234              if line[-1] == ':':
  2235                  self.code.label(line[:-1])
  2236                  continue
  2237  
  2238              # instructions
  2239              if line[0] != '.':
  2240                  self.code.instr(Instruction.parse(line))
  2241                  continue
  2242  
  2243              # parse the command
  2244              cmd = Command.parse(line)
  2245              func = self._commands.get(cmd.cmd)
  2246  
  2247              # handle the command
  2248              if func is not None:
  2249                  func(cmd.args)
  2250              else:
  2251                  raise SyntaxError('invalid assembly command: ' + cmd.cmd)
  2252  
  2253      def _reloc(self, rip: int = 0):
  2254          for block in self.code.blocks:
  2255              for instr in block.body:
  2256                  rip += self._reloc_one(instr, rip)
  2257  
  2258      def _reloc_one(self, instr: Instr, rip: int) -> int:
  2259          if not isinstance(instr, (X86Instr, BranchInstr)):
  2260              return instr.size(rip)
  2261          elif instr.instr.is_branch_label and isinstance(instr.instr.operands[0], Label):
  2262              return self._reloc_branch(instr.instr, rip)
  2263          else:
  2264              return instr.resize(self._reloc_normal(instr.instr, rip))
  2265  
  2266      def _reloc_branch(self, instr: Instruction, rip: int) -> int:
  2267          instr.operands[0].resolve(self.code.get(instr.operands[0].name) - rip - instr.size)
  2268          return instr.size
  2269  
  2270      def _reloc_normal(self, instr: Instruction, rip: int) -> int:
  2271          msg = []
  2272          ops = instr.operands
  2273  
  2274          # relocate RIP relative operands
  2275          for i, op in enumerate(ops):
  2276              if self._is_rip_relative(op):
  2277                  if self.code.has(str(op.disp.ref)):
  2278                      self._reloc_static(op.disp, msg, rip + instr.size)
  2279                  else:
  2280                      raise SyntaxError('unresolved reference to name ' + str(op.disp.ref))
  2281  
  2282          # attach comments if any
  2283          instr.comments = ', '.join(msg) or instr.comments
  2284          return instr.size
  2285  
  2286      def _reloc_static(self, ref: Reference, msg: List[str], rip: int):
  2287          msg.append('%s+%d(%%rip)' % (ref.ref, ref.disp))
  2288          ref.resolve(self.code.get(str(ref.ref)) - rip)
  2289  
  2290      def _declare(self, protos: PrototypeMap):
  2291          if OUTPUT_RAW:
  2292              self._declare_body_raw()
  2293          else:
  2294              self._declare_body()
  2295          self._declare_functions(protos)
  2296  
  2297      def _declare_body(self):
  2298          self.out.append('TEXT ·%s(SB), NOSPLIT, $0' % STUB_NAME)
  2299          self.out.append('\tNO_LOCAL_POINTERS')
  2300          self._reloc()
  2301  
  2302          # instruction buffer
  2303          pc = 0
  2304          ins = self.code.instrs
  2305  
  2306          # dump every instruction
  2307          for v in ins:
  2308              self.out.append(('// +%d\n' % pc if WITH_OFFS else '') + v.formatted(pc))
  2309              pc += v.size(pc)
  2310              
  2311      def _declare_body_raw(self):
  2312          self._reloc()
  2313  
  2314          # instruction buffer
  2315          pc = 0
  2316          ins = self.code.instrs
  2317  
  2318          # dump every instruction
  2319          for v in ins:
  2320              self.out.append(v.raw_formatted(pc))
  2321              pc += v.size(pc)
  2322  
  2323      def _declare_function(self, name: str, proto: Prototype):
  2324          offs = 0
  2325          subr = name[1:]
  2326          addr = self.code.get(subr)
  2327          self.subr[subr] = addr
  2328          size = self.code.pcsp(subr, addr)        
  2329  
  2330          if OUTPUT_RAW:
  2331              return
  2332          
  2333          # function header and stack checking
  2334          self.out.append('')
  2335          self.out.append('TEXT ·%s(SB), NOSPLIT | NOFRAME, $0 - %d' % (name, proto.argspace))
  2336          self.out.append('\tNO_LOCAL_POINTERS')
  2337          
  2338          # add stack check if needed
  2339          if size != 0:
  2340              self.out.append('')
  2341              self.out.append('_entry:')
  2342              self.out.append('\tMOVQ (TLS), R14')
  2343              self.out.append('\tLEAQ -%d(SP), R12' % size)
  2344              self.out.append('\tCMPQ R12, 16(R14)')
  2345              self.out.append('\tJBE  _stack_grow')
  2346  
  2347          # function name
  2348          self.out.append('')
  2349          self.out.append('%s:' % subr)
  2350  
  2351          # intialize all the arguments
  2352          for arg in proto.args:
  2353              offs += arg.size
  2354              op, reg = REG_MAP[arg.creg.reg]
  2355              self.out.append('\t%s %s+%d(FP), %s' % (op, arg.name, offs - arg.size, reg))
  2356  
  2357          # the function starts at zero
  2358          if addr == 0 and proto.retv is None:
  2359              self.out.append('\tJMP ·%s(SB)  // %s' % (STUB_NAME, subr))
  2360  
  2361          # Go ASM completely ignores the offset of the JMP instruction,
  2362          # so we need to use indirect jumps instead for tail-call elimination
  2363          elif proto.retv is None:
  2364              self.out.append('\tLEAQ ·%s+%d(SB), AX  // %s' % (STUB_NAME, addr, subr))
  2365              self.out.append('\tJMP AX')
  2366  
  2367          # normal functions, call the real function, and return the result
  2368          else:
  2369              self.out.append('\tCALL ·%s+%d(SB)  // %s' % (STUB_NAME, addr, subr))
  2370              self.out.append('\t%s, %s+%d(FP)' % (' '.join(REG_MAP[proto.retv.creg.reg]), proto.retv.name, offs))
  2371              self.out.append('\tRET')
  2372  
  2373          # add stack growing if needed
  2374          if size != 0:
  2375              self.out.append('')
  2376              self.out.append('_stack_grow:')
  2377              self.out.append('\tCALL runtime·morestack_noctxt<>(SB)')
  2378              self.out.append('\tJMP  _entry')
  2379  
  2380      def _declare_functions(self, protos: PrototypeMap):
  2381          for name, proto in sorted(protos.items()):
  2382              if name[0] == '_':
  2383                  self._declare_function(name, proto)
  2384              else:
  2385                  raise SyntaxError('function prototype must have a "_" prefix: ' + repr(name))
  2386  
  2387      def parse(self, src: List[str], proto: PrototypeMap):
  2388          self.code.instr(Instruction('leaq', [Memory(Register('rip'), Immediate(-7), None), Register('rax')]))
  2389          self.code.instr(Instruction('movq', [Register('rax'), Memory(Register('rsp'), Immediate(8), None)]))
  2390          self.code.instr(Instruction('retq', []))
  2391          self._parse(src)
  2392          # print("DEBUG...")
  2393          # self.code.debug(0, [
  2394          #     X86Instr(Instruction('int3', []))
  2395          #     # X86Instr(Instruction('xorq', [Register('rax'), Register('rax')])),
  2396          #     # X86Instr(Instruction('movq', [Memory(Register('rax'), Immediate(0), None), Register('rax')]))
  2397          # ])
  2398          self._declare(proto)
  2399  
  2400  GOOS = {
  2401      'aix',
  2402      'android',
  2403      'darwin',
  2404      'dragonfly',
  2405      'freebsd',
  2406      'hurd',
  2407      'illumos',
  2408      'js',
  2409      'linux',
  2410      'nacl',
  2411      'netbsd',
  2412      'openbsd',
  2413      'plan9',
  2414      'solaris',
  2415      'windows',
  2416      'zos',
  2417  }
  2418  
  2419  GOARCH = {
  2420      '386',
  2421      'amd64',
  2422      'amd64p32',
  2423      'arm',
  2424      'armbe',
  2425      'arm64',
  2426      'arm64be',
  2427      'ppc64',
  2428      'ppc64le',
  2429      'mips',
  2430      'mipsle',
  2431      'mips64',
  2432      'mips64le',
  2433      'mips64p32',
  2434      'mips64p32le',
  2435      'ppc',
  2436      'riscv',
  2437      'riscv64',
  2438      's390',
  2439      's390x',
  2440      'sparc',
  2441      'sparc64',
  2442      'wasm',
  2443  }
  2444  
  2445  def make_subr_filename(name: str) -> str:
  2446      name = os.path.basename(name)
  2447      base = os.path.splitext(name)[0].rsplit('_', 2)
  2448  
  2449      # construct the new name
  2450      if base[-1] in GOOS:
  2451          return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1])
  2452      elif base[-1] not in GOARCH:
  2453          return '%s_subr.go' % '_'.join(base)
  2454      elif len(base) > 2 and base[-2] in GOOS:
  2455          return '%s_subr_%s_%s.go' % ('_'.join(base[:-2]), base[-2], base[-1])
  2456      else:
  2457          return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1])
  2458  
  2459  def main():
  2460      src = []
  2461      asm = Assembler()
  2462      
  2463      
  2464      # check for arguments
  2465      if len(sys.argv) < 3:
  2466          print('* usage: %s [-r|-d] <output-file> <clang-asm> ...' % sys.argv[0], file = sys.stderr)
  2467          sys.exit(1)
  2468  
  2469      # check if optional flag is enabled
  2470      global OUTPUT_RAW
  2471      OUTPUT_RAW = False
  2472      if len(sys.argv) >= 4:
  2473          i = 0
  2474          while i<len(sys.argv):
  2475              flag = sys.argv[i]
  2476              if flag == '-r':
  2477                  OUTPUT_RAW = True
  2478                  for j in range(i, len(sys.argv)-1):
  2479                      sys.argv[j] = sys.argv[j + 1]  
  2480                  sys.argv.pop()
  2481                  continue
  2482              i += 1
  2483              
  2484      # parse the prototype
  2485      with open(os.path.splitext(sys.argv[1])[0] + '.go', 'r', newline = None) as fp:
  2486          pkg, proto = PrototypeMap.parse(fp.read())
  2487  
  2488      # read all the sources, and combine them together
  2489      for fn in sys.argv[2:]:
  2490          with open(fn, 'r', newline = None) as fp:
  2491              src.extend(fp.read().splitlines())
  2492  
  2493      # convert the original sources
  2494      if OUTPUT_RAW:
  2495          asm.out.append('// +build amd64')
  2496          asm.out.append('// Code generated by asm2asm, DO NOT EDIT.')
  2497          asm.out.append('')
  2498          asm.out.append('package %s' % pkg)
  2499          asm.out.append('')
  2500          ## native text
  2501          asm.out.append('var Text%s = []byte{' % STUB_NAME)
  2502      else:
  2503          asm.out.append('// +build !noasm !appengine')
  2504          asm.out.append('// Code generated by asm2asm, DO NOT EDIT.')
  2505          asm.out.append('')
  2506          asm.out.append('#include "go_asm.h"')
  2507          asm.out.append('#include "funcdata.h"')
  2508          asm.out.append('#include "textflag.h"')
  2509          asm.out.append('')
  2510          
  2511      asm.parse(src, proto)
  2512  
  2513      if OUTPUT_RAW:
  2514          asrc = os.path.splitext(sys.argv[1])[0]
  2515          asrc = asrc[:asrc.rfind('_')] + '_text_amd64.go'
  2516      else:
  2517          asrc = os.path.splitext(sys.argv[1])[0] + '.s'
  2518        
  2519      # save the converted result  
  2520      with open(asrc, 'w')  as fp:
  2521          for line in asm.out:
  2522              print(line, file = fp)
  2523          if OUTPUT_RAW:
  2524              print('}', file = fp)
  2525  
  2526      # calculate the subroutine stub file name
  2527      subr = make_subr_filename(sys.argv[1])
  2528      subr = os.path.join(os.path.dirname(sys.argv[1]), subr)
  2529  
  2530      # save the compiled code stub
  2531      with open(subr, 'w') as fp:
  2532          print('// +build !noasm !appengine', file = fp)
  2533          print('// Code generated by asm2asm, DO NOT EDIT.', file = fp)
  2534          print(file = fp)
  2535          print('package %s' % pkg, file = fp)
  2536                
  2537          # also save the actual function addresses if any
  2538          if not asm.subr:
  2539              return 
  2540          
  2541          if OUTPUT_RAW:
  2542              print(file = fp)
  2543              print('import (\n\t`github.com/goshafaq/sonic/loader`\n)', file = fp)
  2544              
  2545              # dump every entry for all functions
  2546              print(file = fp)
  2547              print('const (', file = fp)
  2548              for name in asm.code.funcs.keys():
  2549                  addr = asm.code.get(name)
  2550                  if addr is not None:
  2551                      print(f'    _entry_{name} = %d' % addr, file = fp)
  2552              print(')', file = fp)
  2553              
  2554              # dump max stack depth for all functions
  2555              print(file = fp)
  2556              print('const (', file = fp)
  2557              for name in asm.code.funcs.keys():
  2558                  print('    _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp)
  2559              print(')', file = fp)
  2560  
  2561              # dump every text size for all functions
  2562              print(file = fp)
  2563              print('const (', file = fp)
  2564              for name, pcsp in asm.code.funcs.items():
  2565                  if pcsp is not None:
  2566                      # print(f'before {name} optimize {pcsp}')
  2567                      pcsp.optimize()
  2568                      # print(f'after {name} optimize {pcsp}')
  2569                      print(f'    _size_{name} = %d' % (pcsp.maxpc - pcsp.entry), file = fp)
  2570              print(')', file = fp)
  2571              
  2572              # dump every pcsp for all functions
  2573              print(file = fp)
  2574              print('var (', file = fp)
  2575              for name, pcsp in asm.code.funcs.items():
  2576                  if pcsp is not None:
  2577                      print(f'    _pcsp_{name} = %s' % pcsp, file = fp)
  2578              print(')', file = fp)
  2579              
  2580              # insert native entry info
  2581              print(file = fp)
  2582              print('var Funcs = []loader.CFunc{', file = fp)
  2583              print('    {"%s", 0, %d, 0, nil},' % (STUB_NAME, STUB_SIZE), file = fp)
  2584              # dump every native function info for all functions
  2585              for name in asm.code.funcs.keys():
  2586                  print('    {"%s", _entry_%s, _size_%s, _stack_%s, _pcsp_%s},' % (name, name, name, name, name), file = fp)
  2587              print('}', file = fp)
  2588  
  2589          else:
  2590              # native entry for entry function
  2591              print(file = fp)
  2592              print('//go:nosplit', file = fp)
  2593              print('//go:noescape', file = fp)
  2594              print('//goland:noinspection ALL', file = fp)
  2595              print('func %s() uintptr' % STUB_NAME, file = fp)
  2596              
  2597              # dump exported function entry for exported functions
  2598              print(file = fp)
  2599              print('var (', file = fp)
  2600              mlen = max(len(s) for s in asm.subr)
  2601              for name, entry in asm.subr.items():
  2602                  print('    _subr_%s uintptr = %s() + %d' % (name.ljust(mlen, ' '), STUB_NAME, entry), file = fp)
  2603              print(')', file = fp)
  2604  
  2605              # dump max stack depth for exported functions
  2606              print(file = fp)
  2607              print('const (', file = fp)
  2608              for name in asm.subr.keys():
  2609                  print('    _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp)
  2610              print(')', file = fp)
  2611  
  2612              # assign subroutine offsets to '_' to mute the "unused" warnings
  2613              print(file = fp)
  2614              print('var (', file = fp)
  2615              for name in asm.subr:
  2616                  print('    _ = _subr_%s' % name, file = fp)
  2617              print(')', file = fp)
  2618              
  2619              # dump every constant
  2620              print(file = fp)
  2621              print('const (', file = fp)
  2622              for name in asm.subr:
  2623                  print('    _ = _stack_%s' % name, file = fp)
  2624              else:
  2625                  print(')', file = fp)
  2626  
# script entry point; see main() for the command-line arguments
if __name__ == '__main__':
    main()