github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/tools/asm2asm/asm2asm.py (about) 1 #!/usr/bin/env python3 2 # -*- coding: utf-8 -*- 3 4 import os 5 import sys 6 import string 7 import itertools 8 import functools 9 10 from typing import Any 11 from typing import Dict 12 from typing import List 13 from typing import Type 14 from typing import Tuple 15 from typing import Union 16 from typing import Callable 17 from typing import Iterable 18 from typing import Optional 19 20 from peachpy import x86_64 21 from peachpy.x86_64 import generic 22 from peachpy.x86_64 import XMMRegister 23 from peachpy.x86_64.operand import is_rel32 24 from peachpy.x86_64.operand import MemoryOperand 25 from peachpy.x86_64.operand import MemoryAddress 26 from peachpy.x86_64.operand import RIPRelativeOffset 27 from peachpy.x86_64.instructions import Instruction as PInstr 28 from peachpy.x86_64.instructions import BranchInstruction 29 30 ### Instruction Parser (GAS Syntax) ### 31 32 class Label: 33 name: str 34 offs: Optional[int] 35 36 def __init__(self, name: str): 37 self.name = name 38 self.offs = None 39 40 def __str__(self): 41 return self.name 42 43 def __repr__(self): 44 if self.offs is None: 45 return '{LABEL %s (unresolved)}' % self.name 46 else: 47 return '{LABEL %s (offset: %d)}' % (self.name, self.offs) 48 49 def resolve(self, offs: int): 50 self.offs = offs 51 52 class Index: 53 base : 'Register' 54 scale : int 55 56 def __init__(self, base: 'Register', scale: int = 1): 57 self.base = base 58 self.scale = scale 59 60 def __str__(self): 61 if self.scale == 1: 62 return ',%s' % self.base 63 elif self.scale >= 2: 64 return ',%s,%d' % (self.base, self.scale) 65 else: 66 raise RuntimeError('invalid parser state: invalid scale') 67 68 def __repr__(self): 69 if self.scale == 1: 70 return repr(self.base) 71 elif self.scale >= 2: 72 return '%d * %r' % (self.scale, self.base) 73 else: 74 raise RuntimeError('invalid parser state: invalid scale') 75 76 class Memory: 77 base : Optional['Register'] 78 disp : Optional['Displacement'] 79 index : Optional[Index] 80 81 def __init__(self, base: Optional['Register'], disp: Optional['Displacement'], index: Optional[Index]): 82 self.base = base 83 self.disp = disp 84 self.index = index 85 self._validate() 86 87 def __str__(self): 88 return '%s(%s%s)' % ( 89 '' if self.disp is None else self.disp, 90 '' if self.base is None else self.base, 91 '' if self.index is None else self.index 92 ) 93 94 def __repr__(self): 95 return '{MEM %r%s%s}' % ( 96 '' if self.base is None else self.base, 97 '' if self.index is None else ' + ' + repr(self.index), 98 '' if self.disp is None else ' + ' + repr(self.disp) 99 ) 100 101 def _validate(self): 102 if self.base is None and self.index is None: 103 raise SyntaxError('either base or index must be specified') 104 105 class Register: 106 reg: str 107 108 def __init__(self, reg: str): 109 self.reg = reg.lower() 110 111 def __str__(self): 112 return '%' + self.reg 113 114 def __repr__(self): 115 return '{REG %s}' % self.reg 116 117 @functools.cached_property 118 def native(self) -> x86_64.registers.Register: 119 if self.reg == 'rip': 120 raise SyntaxError('%rip is not directly accessible') 121 else: 122 return getattr(x86_64.registers, self.reg) 123 124 class Immediate: 125 val: int 126 ref: str 127 128 def __init__(self, val: int): 129 self.ref = '' 130 self.val = val 131 132 def __str__(self): 133 return '$%d' % self.val 134 135 def __repr__(self): 136 return '{IMM bin:%s, oct:%s, dec:%d, hex:%s}' % ( 137 bin(self.val)[2:], 138 oct(self.val)[2:], 139 self.val, 140 hex(self.val)[2:], 141 ) 142 143 class Reference: 144 ref: str 145 disp: int 146 off: Optional[int] 147 148 def __init__(self, ref: str, disp: int = 0): 149 self.ref = ref 150 self.disp = disp 151 self.off = None 152 153 def __str__(self): 154 if self.off is None: 155 return self.ref 156 else: 157 return '$' + str(self.off) 158 159 def __repr__(self): 160 if self.off is None: 161 return '{REF %s + %d (unresolved)}' % (self.ref, self.disp) 162 else: 163 return '{REF %s + %d (offset: %d)}' % (self.ref, self.disp, self.off) 164 165 @property 166 def offset(self) -> int: 167 if self.off is None: 168 raise SyntaxError('unresolved reference to ' + repr(self.ref)) 169 else: 170 return self.off 171 172 def resolve(self, off: int): 173 self.off = self.disp + off 174 175 Operand = Union[ 176 Label, 177 Memory, 178 Register, 179 Immediate, 180 Reference, 181 ] 182 183 Displacement = Union[ 184 Immediate, 185 Reference, 186 ] 187 188 TOKEN_END = 0 189 TOKEN_REG = 1 190 TOKEN_IMM = 2 191 TOKEN_NUM = 3 192 TOKEN_NAME = 4 193 TOKEN_PUNC = 5 194 195 REGISTERS = { 196 'rax' , 'eax' , 'ax' , 'al' , 'ah' , 197 'rbx' , 'ebx' , 'bx' , 'bl' , 'bh' , 198 'rcx' , 'ecx' , 'cx' , 'cl' , 'ch' , 199 'rdx' , 'edx' , 'dx' , 'dl' , 'dh' , 200 'rsi' , 'esi' , 'si' , 'sil' , 201 'rdi' , 'edi' , 'di' , 'dil' , 202 'rbp' , 'ebp' , 'bp' , 'bpl' , 203 'rsp' , 'esp' , 'sp' , 'spl' , 204 'r8' , 'r8d' , 'r8w' , 'r8b' , 205 'r9' , 'r9d' , 'r9w' , 'r9b' , 206 'r10' , 'r10d' , 'r10w' , 'r10b' , 207 'r11' , 'r11d' , 'r11w' , 'r11b' , 208 'r12' , 'r12d' , 'r12w' , 'r12b' , 209 'r13' , 'r13d' , 'r13w' , 'r13b' , 210 'r14' , 'r14d' , 'r14w' , 'r14b' , 211 'r15' , 'r15d' , 'r15w' , 'r15b' , 212 'mm0' , 'mm1' , 'mm2' , 'mm3' , 'mm4' , 'mm5' , 'mm6' , 'mm7' , 213 'xmm0' , 'xmm1' , 'xmm2' , 'xmm3' , 'xmm4' , 'xmm5' , 'xmm6' , 'xmm7' , 214 'xmm8' , 'xmm9' , 'xmm10' , 'xmm11' , 'xmm12' , 'xmm13' , 'xmm14' , 'xmm15' , 215 'xmm16' , 'xmm17' , 'xmm18' , 'xmm19' , 'xmm20' , 'xmm21' , 'xmm22' , 'xmm23' , 216 'xmm24' , 'xmm25' , 'xmm26' , 'xmm27' , 'xmm28' , 'xmm29' , 'xmm30' , 'xmm31' , 217 'ymm0' , 'ymm1' , 'ymm2' , 'ymm3' , 'ymm4' , 'ymm5' , 'ymm6' , 'ymm7' , 218 'ymm8' , 'ymm9' , 'ymm10' , 'ymm11' , 'ymm12' , 'ymm13' , 'ymm14' , 'ymm15' , 219 'ymm16' , 'ymm17' , 'ymm18' , 'ymm19' , 'ymm20' , 'ymm21' , 'ymm22' , 'ymm23' , 220 'ymm24' , 'ymm25' , 'ymm26' , 'ymm27' , 'ymm28' , 'ymm29' , 'ymm30' , 'ymm31' , 221 'zmm0' , 'zmm1' , 'zmm2' , 'zmm3' , 'zmm4' , 'zmm5' , 'zmm6' , 'zmm7' , 222 'zmm8' , 'zmm9' , 'zmm10' , 'zmm11' , 'zmm12' , 'zmm13' , 'zmm14' , 'zmm15' , 223 'zmm16' , 'zmm17' , 'zmm18' , 'zmm19' , 'zmm20' , 'zmm21' , 'zmm22' , 'zmm23' , 224 'zmm24' , 'zmm25' , 'zmm26' , 'zmm27' , 'zmm28' , 'zmm29' , 'zmm30' , 'zmm31' , 225 'rip' , 226 } 227 228 class Token: 229 tag: int 230 val: Union[int, str] 231 232 def __init__(self, tag: int, val: Union[int, str]): 233 self.val = val 234 self.tag = tag 235 236 @classmethod 237 def end(cls): 238 return cls(TOKEN_END, '') 239 240 @classmethod 241 def reg(cls, reg: str): 242 return cls(TOKEN_REG, reg) 243 244 @classmethod 245 def imm(cls, imm: int): 246 return cls(TOKEN_IMM, imm) 247 248 @classmethod 249 def num(cls, num: int): 250 return cls(TOKEN_NUM, num) 251 252 @classmethod 253 def name(cls, name: str): 254 return cls(TOKEN_NAME, name) 255 256 @classmethod 257 def punc(cls, punc: str): 258 return cls(TOKEN_PUNC, punc) 259 260 def __repr__(self): 261 if self.tag == TOKEN_END: 262 return '<END>' 263 elif self.tag == TOKEN_REG: 264 return '<REG %s>' % self.val 265 elif self.tag == TOKEN_IMM: 266 return '<IMM %d>' % self.val 267 elif self.tag == TOKEN_NUM: 268 return '<NUM %d>' % self.val 269 elif self.tag == TOKEN_NAME: 270 return '<NAME %s>' % repr(self.val) 271 elif self.tag == TOKEN_PUNC: 272 return '<PUNC %s>' % repr(self.val) 273 else: 274 return '<UNK:%d %r>' % (self.tag, self.val) 275 276 class Tokenizer: 277 pos: int 278 src: str 279 280 def __init__(self, src: str): 281 self.pos = 0 282 self.src = src 283 284 @property 285 def _ch(self) -> str: 286 return self.src[self.pos] 287 288 @property 289 def _eof(self) -> bool: 290 return self.pos >= len(self.src) 291 292 def _rch(self) -> str: 293 ret, self.pos = self.src[self.pos], self.pos + 1 294 return ret 295 296 def _rid(self, s: str, allow_dot: bool) -> str: 297 while not self._eof and (self._ch == '_' or self._ch.isalnum() or (allow_dot and self._ch == '.')): 298 s += self._rch() 299 else: 300 return s 301 302 def _reg(self) -> Token: 303 if self._eof: 304 raise SyntaxError('unexpected EOF when parsing register names') 305 else: 306 return self._regx() 307 308 def _imm(self) -> Token: 309 if self._eof: 310 raise SyntaxError('unexpected EOF when parsing immediate values') 311 else: 312 return self._immx(self._rch()) 313 314 def _regx(self) -> Token: 315 nch = self._rch() 316 reg = self._rid(nch, allow_dot = False).lower() 317 318 # check for register names 319 if reg not in REGISTERS: 320 raise SyntaxError('invalid register: ' + reg) 321 else: 322 return Token.reg(reg) 323 324 def _immv(self, ch: str) -> int: 325 while not self._eof and self._ch in string.digits: 326 ch += self._rch() 327 else: 328 return int(ch) 329 330 def _immx(self, ch: str) -> Token: 331 if ch.isdigit(): 332 return Token.imm(self._immv(ch)) 333 elif ch == '-': 334 return Token.imm(-self._immv(self._rch())) 335 else: 336 raise SyntaxError('unexpected character when parsing immediate value: ' + ch) 337 338 def _name(self, ch: str) -> Token: 339 return Token.name(self._rid(ch, allow_dot = True)) 340 341 def _read(self, ch: str) -> Token: 342 if ch == '%': 343 return self._reg() 344 elif ch == '$': 345 return self._imm() 346 elif ch == '-': 347 return Token.num(-self._immv(self._rch())) 348 elif ch == '+': 349 return Token.num(self._immv(self._rch())) 350 elif ch.isdigit(): 351 return Token.num(self._immv(ch)) 352 elif ch.isidentifier(): 353 return self._name(ch) 354 elif ch in ('(', ')', ',', '*'): 355 return Token.punc(ch) 356 else: 357 raise SyntaxError('invalid character: ' + repr(ch)) 358 359 def next(self) -> Token: 360 while not self._eof and self._ch.isspace(): 361 self.pos += 1 362 else: 363 return Token.end() if self._eof else self._read(self._rch()) 364 365 class Instruction: 366 comments: str 367 mnemonic: str 368 operands: List[Operand] 369 370 def __init__(self, mnemonic: str, operands: List[Operand]): 371 self.comments = '' 372 self.operands = operands 373 self.mnemonic = mnemonic.lower() 374 375 def __str__(self): 376 ops = ', '.join(map(str, self.operands)) 377 com = self.comments and ' /* %s */' % self.comments 378 379 # ordinal instructions 380 if not self.is_branch: 381 return '%-12s %s%s' % (self.mnemonic, ops, com) 382 elif len(self.operands) != 1: 383 raise SyntaxError('invalid branch instruction: ' + self.mnemonic) 384 elif isinstance(self.operands[0], Label): 385 return '%-12s %s%s' % (self.mnemonic, ops, com) 386 else: 387 return '%-12s *%s%s' % (self.mnemonic, ops, com) 388 389 def __repr__(self): 390 return '{INSTR %s: %s%s}' % ( 391 self.mnemonic, 392 ', '.join(map(repr, self.operands)), 393 self.comments and ' (%s)' % self.comments 394 ) 395 396 class Basic: 397 @staticmethod 398 def INT3(*args, **kwargs): 399 return x86_64.INT(3, *args, **kwargs) 400 401 @staticmethod 402 def MOVQ(*args, **kwargs): 403 if not any(isinstance(v, XMMRegister) for v in args): 404 return x86_64.MOV(*args, **kwargs) 405 else: 406 return x86_64.MOVQ(*args, **kwargs) 407 408 class BitShift: 409 op: Type[PInstr] 410 411 def __init__(self, op: Type[PInstr]): 412 self.op = op 413 414 def __call__(self, *args, **kwargs): 415 if len(args) != 1: 416 return self.op(*args, **kwargs) 417 else: 418 return self.op(*args, 1, **kwargs) 419 420 class VectorCompare: 421 fn: int 422 op: Type[PInstr] 423 424 def __init__(self, op: Type[PInstr], fn: int): 425 self.fn = fn 426 self.op = op 427 428 def __call__(self, *args, **kwargs): 429 return self.op(*args, self.fn, **kwargs) 430 431 __instr_map__ = { 432 'INT3' : Basic.INT3, 433 'SALB' : BitShift(x86_64.SAL), 434 'SALW' : BitShift(x86_64.SAL), 435 'SALL' : BitShift(x86_64.SAL), 436 'SALQ' : BitShift(x86_64.SAL), 437 'SARB' : BitShift(x86_64.SAR), 438 'SARW' : BitShift(x86_64.SAR), 439 'SARL' : BitShift(x86_64.SAR), 440 'SARQ' : BitShift(x86_64.SAR), 441 'SHLB' : BitShift(x86_64.SHL), 442 'SHLW' : BitShift(x86_64.SHL), 443 'SHLL' : BitShift(x86_64.SHL), 444 'SHLQ' : BitShift(x86_64.SHL), 445 'SHRB' : BitShift(x86_64.SHR), 446 'SHRW' : BitShift(x86_64.SHR), 447 'SHRL' : BitShift(x86_64.SHR), 448 'SHRQ' : BitShift(x86_64.SHR), 449 'MOVQ' : Basic.MOVQ, 450 'CBTW' : x86_64.CBW, 451 'CWTL' : x86_64.CWDE, 452 'CLTQ' : x86_64.CDQE, 453 'MOVZBW' : x86_64.MOVZX, 454 'MOVZBL' : x86_64.MOVZX, 455 'MOVZWL' : x86_64.MOVZX, 456 'MOVZBQ' : x86_64.MOVZX, 457 'MOVZWQ' : x86_64.MOVZX, 458 'MOVSBW' : x86_64.MOVSX, 459 'MOVSBL' : x86_64.MOVSX, 460 'MOVSWL' : x86_64.MOVSX, 461 'MOVSBQ' : x86_64.MOVSX, 462 'MOVSWQ' : x86_64.MOVSX, 463 'MOVSLQ' : x86_64.MOVSXD, 464 'MOVABSQ' : x86_64.MOV, 465 'VCMPEQPS' : VectorCompare(x86_64.VCMPPS, 0x00), 466 'VCMPTRUEPS' : VectorCompare(x86_64.VCMPPS, 0x0f), 467 } 468 469 @functools.cached_property 470 def _instr(self) -> Union[Type[PInstr], Callable[..., PInstr]]: 471 name = self.mnemonic.upper() 472 func = self.__instr_map__.get(name) 473 474 # not found, resolve as x86_64 instruction 475 if func is None: 476 func = getattr(x86_64, name, None) 477 478 # try with size suffix removed (only for generic instructions) 479 if func is None and name[-1] in 'BWLQ': 480 func = getattr(generic, name[:-1], func) 481 482 # still not found, it should be an error 483 if func is None: 484 raise SyntaxError('unknown instruction: ' + self.mnemonic) 485 else: 486 return func 487 488 @property 489 def jmptab(self) -> Optional[str]: 490 if self.mnemonic == 'leaq' and isinstance(self.operands[0], Memory) and self.operands[0].base.reg == 'rip': 491 dis = self.operands[0].disp 492 if dis and dis.ref.find(CLANG_JUMPTABLE_LABLE) != -1: 493 return dis.ref 494 495 @property 496 def _instr_size(self) -> Optional[int]: 497 ops = self.operands 498 key = self.mnemonic.upper() 499 500 # special case of sign/zero extension instructions 501 if key in self.__instr_size__: 502 return self.__instr_size__[key] 503 504 # check for register operands 505 for op in ops: 506 if isinstance(op, Register): 507 return None 508 509 # check for size suffix, and this only applies to generic instructions 510 if key[-1] not in self.__size_map__ or not hasattr(generic, key[:-1]): 511 raise SyntaxError('ambiguous operand sizes') 512 else: 513 return self.__size_map__[key[-1]] 514 515 __size_map__ = { 516 'B': 1, 517 'W': 2, 518 'L': 4, 519 'Q': 8, 520 } 521 522 __instr_size__ = { 523 'MOVZBW' : 1, 524 'MOVZBL' : 1, 525 'MOVZWL' : 2, 526 'MOVZBQ' : 1, 527 'MOVZWQ' : 2, 528 'MOVSBW' : 1, 529 'MOVSBL' : 1, 530 'MOVSWL' : 2, 531 'MOVSBQ' : 1, 532 'MOVSWQ' : 2, 533 'MOVSLQ' : 4, 534 } 535 536 @staticmethod 537 def _encode_r32(ins: PInstr) -> bytes: 538 ret = [fn(ins.operands) for _, fn in ins.encodings] 539 ret.sort(key = len) 540 return ret[-1] 541 542 @classmethod 543 def _encode_ins(cls, ins: PInstr, force_rel32: bool = False) -> bytes: 544 if not isinstance(ins, BranchInstruction): 545 return ins.encode() 546 elif not is_rel32(ins.operands[0]) or not force_rel32: 547 return ins.encode() 548 else: 549 return cls._encode_r32(ins) 550 551 def _encode_rel(self, rel: Label, sizing: bool, offset: int) -> RIPRelativeOffset: 552 if rel.offs is not None: 553 return RIPRelativeOffset(rel.offs) 554 elif sizing: 555 return RIPRelativeOffset(offset) 556 else: 557 raise SyntaxError('unresolved reference to name: ' + rel.name) 558 559 def _encode_mem(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand: 560 if mem.base is not None and mem.base.reg == 'rip': 561 return self._encode_mem_rip(mem, sizing, offset) 562 else: 563 return self._encode_mem_std(mem, sizing, offset) 564 565 def _encode_mem_rip(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand: 566 if mem.disp is None: 567 return MemoryOperand(RIPRelativeOffset(0)) 568 elif mem.index is not None: 569 raise SyntaxError('%rip relative addresing does not support indexing') 570 elif isinstance(mem.disp, Immediate): 571 return MemoryOperand(RIPRelativeOffset(mem.disp.val)) 572 elif isinstance(mem.disp, Reference): 573 return MemoryOperand(RIPRelativeOffset(offset if sizing else mem.disp.offset)) 574 else: 575 raise RuntimeError('illegal memory displacement') 576 577 def _encode_mem_std(self, mem: Memory, sizing: bool, offset: int) -> MemoryOperand: 578 disp = 0 579 base = None 580 index = None 581 scale = None 582 583 # add optional base 584 if mem.base is not None: 585 base = mem.base.native 586 587 # add optional indexing 588 if mem.index is not None: 589 scale = mem.index.scale 590 index = mem.index.base.native 591 592 # add optional displacement 593 if mem.disp is not None: 594 if isinstance(mem.disp, Immediate): 595 disp = mem.disp.val 596 elif isinstance(mem.disp, Reference): 597 disp = offset if sizing else mem.disp.offset 598 else: 599 raise RuntimeError('illegal memory displacement') 600 601 # construct the memory address 602 return MemoryOperand( 603 size = self._instr_size, 604 address = MemoryAddress(base, index, scale, disp), 605 ) 606 607 def _encode_operands(self, sizing: bool, offset: int) -> Iterable[Any]: 608 for op in self.operands: 609 if isinstance(op, Label): 610 yield self._encode_rel(op, sizing, offset) 611 elif isinstance(op, Memory): 612 yield self._encode_mem(op, sizing, offset) 613 elif isinstance(op, Register): 614 yield op.native 615 elif isinstance(op, Immediate): 616 yield op.val 617 else: 618 raise SyntaxError('cannot encode %s as operand' % repr(op)) 619 620 def _encode_branch_rel(self, rel: Label) -> str: 621 if rel.offs is not None: 622 return self._encode_normal_instr() 623 else: 624 raise RuntimeError('invalid relative branching instruction') 625 626 def _raw_branch_rel(self, rel: Label) -> bytes: 627 if rel.offs is not None: 628 return self._raw_normal_instr() 629 else: 630 raise RuntimeError('invalid relative branching instruction') 631 632 def _encode_branch_mem(self, mem: Memory) -> str: 633 raise NotImplementedError('not implemented: memory indirect jump') 634 635 def _raw_branch_mem(self, mem: Memory) -> bytes: 636 raise NotImplementedError('not implemented: memory indirect jump') 637 638 def _encode_branch_reg(self, reg: Register) -> str: 639 if reg.reg == 'rip': 640 raise SyntaxError('%rip cannot be used as a jump target') 641 elif self.mnemonic != 'jmpq': 642 raise SyntaxError('invalid indirect jump for instruction: ' + self.mnemonic) 643 else: 644 return x86_64.JMP(reg.native).format('go') 645 646 def _raw_branch_reg(self, reg: Register) -> bytes: 647 if reg.reg == 'rip': 648 raise SyntaxError('%rip cannot be used as a jump target') 649 elif self.mnemonic != 'jmpq': 650 raise SyntaxError('invalid indirect jump for instruction: ' + self.mnemonic) 651 else: 652 return x86_64.JMP(reg.native).encode() 653 654 def _encode_branch_instr(self) -> str: 655 if len(self.operands) != 1: 656 raise RuntimeError('illegal branch instruction') 657 elif isinstance(self.operands[0], Label): 658 return self._encode_branch_rel(self.operands[0]) 659 elif isinstance(self.operands[0], Memory): 660 return self._encode_branch_mem(self.operands[0]) 661 elif isinstance(self.operands[0], Register): 662 return self._encode_branch_reg(self.operands[0]) 663 else: 664 raise RuntimeError('invalid operand type ' + repr(self.operands[0])) 665 666 def _raw_branch_instr(self) -> str: 667 if len(self.operands) != 1: 668 raise RuntimeError('illegal branch instruction') 669 elif isinstance(self.operands[0], Label): 670 return self._raw_branch_rel(self.operands[0]) 671 elif isinstance(self.operands[0], Memory): 672 return self._raw_branch_mem(self.operands[0]) 673 elif isinstance(self.operands[0], Register): 674 return self._raw_branch_reg(self.operands[0]) 675 else: 676 raise RuntimeError('invalid operand type ' + repr(self.operands[0])) 677 678 def _encode_normal_instr(self) -> str: 679 ops = self._encode_operands(False, 0) 680 ret = self._instr(*list(ops)[::-1]) 681 682 # encode all instructions as raw bytes 683 if not self.is_branch_label: 684 return self.encode(self._encode_ins(ret), str(self)) 685 else: 686 return self.encode(self._encode_ins(ret, force_rel32 = True), '%s, $%s(%%rip)' % (self, self.operands[0].offs)) 687 688 def _raw_normal_instr(self) -> str: 689 ops = self._encode_operands(False, 0) 690 ret = self._instr(*list(ops)[::-1]) 691 692 # encode all instructions as raw bytes 693 if not self.is_branch_label: 694 return self._encode_ins(ret) 695 else: 696 return self._encode_ins(ret, force_rel32 = True) 697 698 699 @property 700 def size(self) -> int: 701 return self.encoded_size(0) 702 703 @functools.cached_property 704 def encoded(self) -> str: 705 if self.is_branch: 706 return self._encode_branch_instr() 707 else: 708 return self._encode_normal_instr() 709 710 def raw(self) -> bytes: 711 if self.is_branch: 712 return self._raw_branch_instr() 713 else: 714 return self._raw_normal_instr() 715 716 @functools.cached_property 717 def is_return(self) -> bool: 718 return self._instr is x86_64.RET 719 720 @functools.cached_property 721 def is_invoke(self) -> bool: 722 return self._instr is x86_64.CALL 723 724 @functools.cached_property 725 def is_branch(self) -> bool: 726 try: 727 return self.is_invoke or issubclass(self._instr, BranchInstruction) 728 except TypeError: 729 return False 730 731 @functools.cached_property 732 def is_jmp(self) -> bool: 733 return self._instr is x86_64.JMP 734 735 @functools.cached_property 736 def is_jmpq(self) -> bool: 737 return self.mnemonic == 'jmpq' 738 739 @property 740 def is_branch_label(self) -> bool: 741 return self.is_branch and isinstance(self.operands[0], Label) 742 743 def encoded_size(self, offset: int) -> int: 744 op = self._encode_operands(True, offset) 745 return len(self._encode_ins(self._instr(*list(op)[::-1]), force_rel32 = True)) 746 747 @classmethod 748 def parse(cls, line: str) -> 'Instruction': 749 lex = Tokenizer(line) 750 ntk = lex.next() 751 752 # the first token must be a name 753 if ntk.tag != TOKEN_NAME: 754 raise SyntaxError('mnemonic expected, got ' + repr(ntk)) 755 else: 756 return cls(ntk.val, cls._parse_operands(lex)) 757 758 @staticmethod 759 def encode(buf: bytes, comments: str = '') -> str: 760 i = 0 761 r = [] 762 n = len(buf) 763 764 # try "QUAD" first 765 while i < n - 7: 766 r.append('QUAD $0x%016x' % int.from_bytes(buf[i:i + 8], 'little')) 767 i += 8 768 769 # then "LONG" 770 while i < n - 3: 771 r.append('LONG $0x%08x' % int.from_bytes(buf[i:i + 4], 'little')) 772 i += 4 773 774 # then "SHORT" 775 while i < n - 1: 776 r.append('WORD $0x%04x' % int.from_bytes(buf[i:i + 2], 'little')) 777 i += 2 778 779 # then "BYTE" 780 while i < n: 781 r.append('BYTE $0x%02x' % buf[i]) 782 i += 1 783 784 # join them together, and attach the comment if any 785 if not comments: 786 return '; '.join(r) 787 else: 788 return '%s // %s' % ('; '.join(r), comments) 789 790 Reg = Optional[Register] 791 Disp = Optional[Displacement] 792 793 @classmethod 794 def _parse_mend(cls, ntk: Token, base: Reg, index: Register, scale: int, disp: Disp) -> Operand: 795 if ntk.tag != TOKEN_PUNC or ntk.val != ')': 796 raise SyntaxError('")" expected, got ' + repr(ntk)) 797 else: 798 return Memory(base, disp, Index(index, scale)) 799 800 @classmethod 801 def _parse_base(cls, lex: Tokenizer, ntk: Token, disp: Disp) -> Operand: 802 if ntk.tag == TOKEN_REG: 803 return cls._parse_idelim(lex, lex.next(), Register(ntk.val), disp) 804 elif ntk.tag == TOKEN_PUNC and ntk.val == ',': 805 return cls._parse_ibase(lex, lex.next(), None, disp) 806 else: 807 raise SyntaxError('register expected, got ' + repr(ntk)) 808 809 @classmethod 810 def _parse_ibase(cls, lex: Tokenizer, ntk: Token, base: Reg, disp: Disp) -> Operand: 811 if ntk.tag != TOKEN_REG: 812 raise SyntaxError('register expected, got ' + repr(ntk)) 813 else: 814 return cls._parse_sdelim(lex, lex.next(), base, Register(ntk.val), disp) 815 816 @classmethod 817 def _parse_idelim(cls, lex: Tokenizer, ntk: Token, base: Reg, disp: Disp) -> Operand: 818 if ntk.tag == TOKEN_END: 819 raise SyntaxError('unexpected EOF when parsing memory operands') 820 elif ntk.tag == TOKEN_PUNC and ntk.val == ')': 821 return Memory(base, disp, None) 822 elif ntk.tag == TOKEN_PUNC and ntk.val == ',': 823 return cls._parse_ibase(lex, lex.next(), base, disp) 824 else: 825 raise SyntaxError('"," or ")" expected, got ' + repr(ntk)) 826 827 @classmethod 828 def _parse_iscale(cls, lex: Tokenizer, ntk: Token, base: Reg, index: Register, disp: Disp) -> Operand: 829 if ntk.tag != TOKEN_NUM: 830 raise SyntaxError('integer expected, got ' + repr(ntk)) 831 elif ntk.val not in (1, 2, 4, 8): 832 raise SyntaxError('indexing scale can only be 1, 2, 4 or 8') 833 else: 834 return cls._parse_mend(lex.next(), base, index, ntk.val, disp) 835 836 @classmethod 837 def _parse_sdelim(cls, lex: Tokenizer, ntk: Token, base: Reg, index: Register, disp: Disp) -> Operand: 838 if ntk.tag == TOKEN_END: 839 raise SyntaxError('unexpected EOF when parsing memory operands') 840 elif ntk.tag == TOKEN_PUNC and ntk.val == ')': 841 return Memory(base, disp, Index(index)) 842 elif ntk.tag == TOKEN_PUNC and ntk.val == ',': 843 return cls._parse_iscale(lex, lex.next(), base, index, disp) 844 else: 845 raise SyntaxError('"," or ")" expected, got ' + repr(ntk)) 846 847 @classmethod 848 def _parse_refmem(cls, lex: Tokenizer, ntk: Token, ref: str) -> Operand: 849 if ntk.tag == TOKEN_END: 850 return Label(ref) 851 elif ntk.tag == TOKEN_PUNC and ntk.val == '(' : 852 return cls._parse_memory(lex, ntk, Reference(ref, 0)) 853 elif ntk.tag == TOKEN_NUM: 854 ntk = lex.next() 855 if ntk.tag == TOKEN_PUNC and ntk.val == '(': 856 return cls._parse_refmem(lex, ntk, Reference(ref, ntk.val)) 857 858 raise SyntaxError(f'identifier "{ref}" must either be a label or a displacement reference') 859 860 @classmethod 861 def _parse_memory(cls, lex: Tokenizer, ntk: Token, disp: Optional[Displacement]) -> Operand: 862 if ntk.tag != TOKEN_PUNC or ntk.val != '(': 863 raise SyntaxError('"(" expected, got ' + repr(ntk)) 864 else: 865 return cls._parse_base(lex, lex.next(), disp) 866 867 @classmethod 868 def _parse_operand(cls, lex: Tokenizer, ntk: Token, can_indir: bool = True) -> Operand: 869 if ntk.tag == TOKEN_REG: 870 return Register(ntk.val) 871 elif ntk.tag == TOKEN_IMM: 872 return Immediate(ntk.val) 873 elif ntk.tag == TOKEN_NUM: 874 return cls._parse_memory(lex, lex.next(), Immediate(ntk.val)) 875 elif ntk.tag == TOKEN_NAME: 876 return cls._parse_refmem(lex, lex.next(), ntk.val) 877 elif ntk.tag == TOKEN_PUNC and ntk.val == '(': 878 return cls._parse_memory(lex, ntk, None) 879 elif ntk.tag == TOKEN_PUNC and ntk.val == '*' and can_indir: 880 return cls._parse_operand(lex, lex.next(), False) 881 else: 882 raise SyntaxError('invalid token: ' + repr(ntk)) 883 884 @classmethod 885 def _parse_operands(cls, lex: Tokenizer) -> List[Operand]: 886 ret = [] 887 ntk = lex.next() 888 889 # check for empty operand 890 if ntk.tag == TOKEN_END: 891 return [] 892 893 # parse every operand 894 while True: 895 ret.append(cls._parse_operand(lex, ntk)) 896 ntk = lex.next() 897 898 # check for the ',' delimiter or the end of input 899 if ntk.tag == TOKEN_PUNC and ntk.val == ',': 900 ntk = lex.next() 901 elif ntk.tag != TOKEN_END: 902 raise SyntaxError('"," expected, got ' + repr(ntk)) 903 else: 904 return ret 905 906 ### Prototype Parser ### 907 908 ARGS_ORDER_C = [ 909 Register('rdi'), 910 Register('rsi'), 911 Register('rdx'), 912 Register('rcx'), 913 Register('r8'), 914 Register('r9'), 915 ] 916 917 ARGS_ORDER_GO = [ 918 Register('rax'), 919 Register('rbx'), 920 Register('rcx'), 921 Register('rdi'), 922 Register('rsi'), 923 Register('r8'), 924 ] 925 926 FPARGS_ORDER = [ 927 Register('xmm0'), 928 Register('xmm1'), 929 Register('xmm2'), 930 Register('xmm3'), 931 Register('xmm4'), 932 Register('xmm5'), 933 Register('xmm6'), 934 Register('xmm7'), 935 ] 936 937 class Parameter: 938 name : str 939 size : int 940 creg : Register 941 goreg: Register 942 943 def __init__(self, name: str, size: int, reg: Register, goreg: Register): 944 self.creg = reg 945 self.goreg = reg 946 self.name = name 947 self.size = size 948 949 def __repr__(self): 950 return '<ARG %s(%d): %s>' % (self.name, self.size, self.creg) 951 952 class Pcsp: 953 entry: int 954 maxpc: int 955 out : List[Tuple[int, int]] 956 pc : int 957 sp : int 958 959 def __init__(self, entry: int): 960 self.out = [] 961 self.maxpc = entry 962 self.entry = entry 963 self.pc = entry 964 self.sp = 0 965 966 def __str__(self) -> str: 967 ret = '[][2]uint32{\n' 968 for pc, sp in self.out: 969 ret += ' {%d, %d},\n' % (pc, sp) 970 return ret + ' }' 971 972 def optimize(self): 973 # push the last record 974 self.out.append((self.pc - self.entry, self.sp)) 975 # sort by pc 976 self.out.sort(key=lambda x: x[0]) 977 # NOTICE: first pair {1, 0} to be compitable with golang 978 tmp = [(1, 0)] 979 lpc, lsp = 0, -1 980 for pc, sp in self.out: 981 # sp changed, push new record 982 if pc != lpc and sp != lsp: 983 tmp.append((pc, sp)) 984 # sp unchanged, replace with the higher pc 985 if pc != lpc and sp == lsp: 986 if len(tmp) > 0: 987 tmp.pop(-1) 988 tmp.append((pc, sp)) 989 990 lpc, lsp = pc, sp 991 self.out = tmp 992 993 def update(self, dpc: int, dsp: int): 994 self.out.append((self.pc - self.entry, self.sp)) 995 self.pc += dpc 996 self.sp += dsp 997 if self.pc > self.maxpc: 998 self.maxpc = self.pc 999 1000 class Prototype: 1001 args: List[Parameter] 1002 retv: Optional[Parameter] 1003 1004 def __init__(self, retv: Optional[Parameter], args: List[Parameter]): 1005 self.retv = retv 1006 self.args = args 1007 1008 def __repr__(self): 1009 if self.retv is None: 1010 return '<PROTO (%s)>' % repr(self.args) 1011 else: 1012 return '<PROTO (%r) -> %r>' % (self.args, self.retv) 1013 1014 @property 1015 def argspace(self) -> int: 1016 return sum( 1017 [v.size for v in self.args], 1018 (0 if self.retv is None else self.retv.size) 1019 ) 1020 1021 class PrototypeMap(Dict[str, Prototype]): 1022 @staticmethod 1023 def _dv(c: str) -> int: 1024 if c == '(': 1025 return 1 1026 elif c == ')': 1027 return -1 1028 else: 1029 return 0 1030 1031 @staticmethod 1032 def _tk(s: str, p: str) -> bool: 1033 return s.startswith(p) and (s == p or s[len(p)].isspace()) 1034 1035 @classmethod 1036 def _punc(cls, s: str) -> bool: 1037 return s in cls.__puncs_ 1038 1039 @staticmethod 1040 def _err(msg: str) -> SyntaxError: 1041 return SyntaxError( 1042 msg + ', ' + 1043 'the parser integrated in this tool is just a text-based parser, ' + 1044 'so please keep the companion .go file as simple as possible and do not use defined types' 1045 ) 1046 1047 @staticmethod 1048 def _align(nb: int) -> int: 1049 return (((nb - 1) >> 3) + 1) << 3 1050 1051 @classmethod 1052 def _retv(cls, ret: str) -> Tuple[str, int, Register, Register]: 1053 name, size, xmm = cls._args(ret) 1054 reg = Register('xmm0') if xmm else Register('rax') 1055 return name, size, reg, reg 1056 1057 @classmethod 1058 def _args(cls, arg: str, sv: str = '') -> Tuple[str, int, bool]: 1059 while True: 1060 if not arg: 1061 raise SyntaxError('missing type for parameter: ' + sv) 1062 elif arg[0] != '_' and not arg[0].isalnum(): 1063 return (sv,) + cls._size(arg.strip()) 1064 elif not sv and arg[0].isdigit(): 1065 raise SyntaxError('invalid character: ' + repr(arg[0])) 1066 else: 1067 sv += arg[0] 1068 arg = arg[1:] 1069 1070 @classmethod 1071 def _size(cls, name: str) -> Tuple[int, bool]: 1072 if name[0] == '*': 1073 return cls._align(8), False 1074 elif name in ('int8', 'uint8', 'byte', 'bool'): 1075 return cls._align(1), False 1076 elif name in ('int16', 'uint16'): 1077 return cls._align(2), False 1078 elif name == 'float32': 1079 return cls._align(4), True 1080 elif name in ('int32', 'uint32', 'rune'): 1081 return cls._align(4), False 1082 elif name == 'float64': 1083 return cls._align(8), True 1084 elif name in ('int64', 'uint64', 'uintptr', 'int', 'Pointer', 'unsafe.Pointer'): 1085 return cls._align(8), False 1086 else: 1087 raise cls._err('unrecognized type "%s"' % name) 1088 1089 @classmethod 1090 def _func(cls, src: List[str], idx: int, depth: int = 0) -> Tuple[str, int]: 1091 for i in range(idx, len(src)): 1092 for x in map(cls._dv, src[i]): 1093 if depth + x >= 0: 1094 depth += x 1095 else: 1096 raise cls._err('encountered ")" more than "(" on line %d' % (i + 1)) 1097 else: 1098 if depth == 0: 1099 return ' '.join(src[idx:i + 1]), i + 1 1100 else: 1101 raise cls._err('unexpected EOF when parsing function signatures') 1102 1103 @classmethod 1104 def parse(cls, src: str) -> Tuple[str, 'PrototypeMap']: 1105 idx = 0 1106 pkg = '' 1107 ret = PrototypeMap() 1108 buf = src.splitlines() 1109 1110 # scan through all the lines 1111 while idx < len(buf): 1112 line = buf[idx] 1113 line = line.strip() 1114 1115 # skip empty lines 1116 if not line: 1117 idx += 1 1118 continue 1119 1120 # check for package name 1121 if cls._tk(line, 'package'): 1122 idx, pkg = idx + 1, line[7:].strip().split()[0] 1123 continue 1124 1125 if OUTPUT_RAW: 1126 1127 # extract funcname like "[var ]{funcname} = func(..." 1128 end = line.find('func(') 1129 if end == -1: 1130 idx += 1 1131 continue 1132 name = line[:end].strip() 1133 if name.startswith('var '): 1134 name = name[4:].strip() 1135 1136 # function names must be identifiers 1137 if not name.isidentifier(): 1138 raise cls._err('invalid function prototype: ' + name) 1139 1140 # register a empty prototype 1141 ret[name] = Prototype(None, []) 1142 idx += 1 1143 1144 else: 1145 1146 # only cares about those functions that does not have bodies 1147 if line[-1] == '{' or not cls._tk(line, 'func'): 1148 idx += 1 1149 continue 1150 1151 # prevent type-aliasing primitive types into other names 1152 if cls._tk(line, 'type'): 1153 raise cls._err('please do not declare any type with in the companion .go file') 1154 1155 # find the next function declaration 1156 decl, pos = cls._func(buf, idx) 1157 func, idx = decl[4:].strip(), pos 1158 1159 # find the beginning '(' 1160 nd = 1 1161 pos = func.find('(') 1162 1163 # must have a '(' 1164 if pos == -1: 1165 raise cls._err('invalid function prototype: ' + decl) 1166 1167 # extract the name and signature 1168 args = '' 1169 name = func[:pos].strip() 1170 func = func[pos + 1:].strip() 1171 1172 # skip the method declaration 1173 if not name: 1174 continue 1175 1176 # function names must be identifiers 1177 if not name.isidentifier(): 1178 raise cls._err('invalid function prototype: ' + decl) 1179 1180 # extract the argument list 1181 while nd and func: 1182 nch = func[0] 1183 func = func[1:] 1184 1185 # adjust the nesting level 1186 nd += cls._dv(nch) 1187 args += nch 1188 1189 # check for EOF 1190 if not nd: 1191 func = func.strip() 1192 else: 1193 raise cls._err('unexpected EOF when parsing function prototype: ' + decl) 1194 1195 # check for multiple returns 1196 if ',' in func: 1197 raise cls._err('can only return a single value (detected by looking for "," within the return list)') 1198 1199 # check for return signature 1200 if not func: 1201 retv = None 1202 elif func[0] == '(' and func[-1] == ')': 1203 retv = Parameter(*cls._retv(func[1:-1])) 1204 else: 1205 raise SyntaxError('badly formatted return argument (please use parenthesis and proper arguments naming): ' + func) 1206 1207 # extract the argument list 1208 if not args[:-1]: 1209 args, alens, axmm = [], [], [] 1210 else: 1211 args, alens, axmm = list(zip(*[cls._args(v.strip()) for v in args[:-1].split(',')])) 1212 1213 # check for the result 1214 cregs = [] 1215 goregs = [] 1216 idxs = [0, 0] 1217 1218 # split the integer & floating point registers 1219 for xmm in axmm: 1220 key = 0 if xmm else 1 1221 seq = FPARGS_ORDER if xmm else ARGS_ORDER_C 1222 goseq = FPARGS_ORDER if xmm else ARGS_ORDER_GO 1223 1224 # check the argument count 1225 if idxs[key] >= len(seq): 1226 raise cls._err("too many arguments, consider pack some into a pointer") 1227 1228 # add the register 1229 cregs.append(seq[idxs[key]]) 1230 goregs.append(goseq[idxs[key]]) 1231 idxs[key] += 1 1232 1233 # register the prototype 1234 ret[name] = Prototype(retv, [ 1235 Parameter(arg, size, creg, goreg) 1236 for arg, size, creg, goreg in zip(args, alens, cregs, goregs) 1237 ]) 1238 1239 # all done 1240 return pkg, ret 1241 1242 ### Assembly Source Parser ### 1243 1244 ESC_IDLE = 0 # escape parser is idleing 1245 ESC_ISTR = 1 # currently inside a string 1246 ESC_BKSL = 2 # encountered backslash, prepare for escape sequences 1247 ESC_HEX0 = 3 # expect the first hexadecimal character of a "\x" escape 1248 ESC_HEX1 = 4 # expect the second hexadecimal character of a "\x" escape 1249 ESC_OCT1 = 5 # expect the second octal character of a "\000" escape 1250 ESC_OCT2 = 6 # expect the third octal character of a "\000" escape 1251 1252 class Command: 1253 cmd : str 1254 args : List[Union[str, bytes]] 1255 1256 def __init__(self, cmd: str, args: List[Union[str, bytes]]): 1257 self.cmd = cmd 1258 self.args = args 1259 1260 def __repr__(self): 1261 return '<CMD %s %s>' % (self.cmd, ', '.join(map(repr, self.args))) 1262 1263 @classmethod 1264 def parse(cls, src: str) -> 'Command': 1265 val = src.split(None, 1) 1266 cmd = val[0] 1267 1268 # no parameters 1269 if len(val) == 1: 1270 return cls(cmd, []) 1271 1272 # extract the argument string 1273 idx = 0 1274 esc = 0 1275 pos = None 1276 args = [] 1277 vstr = val[1] 1278 1279 # scan through the whole string 1280 while idx < len(vstr): 1281 nch = vstr[idx] 1282 idx += 1 1283 1284 # mark the start of the argument 1285 if pos is None: 1286 pos = idx - 1 1287 1288 # encountered the delimiter outside of a string 1289 if nch == ',' and esc == ESC_IDLE: 1290 pos, p = None, pos 1291 args.append(vstr[p:idx - 1].strip()) 1292 1293 # start of a string 1294 elif nch == '"' and esc == ESC_IDLE: 1295 esc = ESC_ISTR 1296 1297 # end of string 1298 elif nch == '"' and esc == ESC_ISTR: 1299 esc = ESC_IDLE 1300 pos, p = None, pos 1301 args.append(vstr[p:idx].strip()[1:-1].encode('utf-8').decode('unicode_escape')) 1302 1303 # escape characters 1304 elif nch == '\\' and esc == ESC_ISTR: 1305 esc = ESC_BKSL 1306 1307 # hexadecimal escape characters (3 chars) 1308 elif esc == ESC_BKSL and nch == 'x': 1309 esc = ESC_HEX0 1310 1311 # octal escape characters (3 chars) 1312 elif esc == ESC_BKSL and nch in string.octdigits: 1313 esc = ESC_OCT1 1314 1315 # generic escape characters (single char) 1316 elif esc == ESC_BKSL and nch in ('a', 'b', 'f', 'r', 'n', 't', 'v', '"', '\\'): 1317 esc = ESC_ISTR 1318 1319 # invalid escape sequence 1320 elif esc == ESC_BKSL: 1321 raise SyntaxError('invalid escape character: ' + repr(nch)) 1322 1323 # normal characters, simply advance to the next character 1324 elif esc in (ESC_IDLE, ESC_ISTR): 1325 pass 1326 1327 # hexadecimal escape characters 1328 elif esc in (ESC_HEX0, ESC_HEX1) and nch.lower() in string.hexdigits: 1329 esc = ESC_HEX1 if esc == ESC_HEX0 else ESC_ISTR 1330 1331 # invalid hexadecimal character 1332 elif esc in (ESC_HEX0, ESC_HEX1): 1333 raise SyntaxError('invalid hexdecimal character: ' + repr(nch)) 1334 1335 # octal escape characters 1336 elif esc in (ESC_OCT1, ESC_OCT2) and nch.lower() in string.octdigits: 1337 esc = ESC_OCT2 if esc == ESC_OCT1 else ESC_ISTR 1338 1339 # at most 3 octal digits 1340 elif esc in (ESC_OCT1, ESC_OCT2): 1341 esc = ESC_ISTR 1342 1343 # illegal state, should not happen 1344 else: 1345 raise RuntimeError('illegal state: %d' % esc) 1346 1347 # check for the last argument 1348 if pos is None: 1349 return cls(cmd, args) 1350 1351 # add the last argument and build the command 1352 args.append(vstr[pos:].strip()) 1353 return cls(cmd, args) 1354 1355 class Expression: 1356 pos: int 1357 src: str 1358 1359 def __init__(self, src: str): 1360 self.pos = 0 1361 self.src = src 1362 1363 @property 1364 def _ch(self) -> str: 1365 return self.src[self.pos] 1366 1367 @property 1368 def _eof(self) -> bool: 1369 return self.pos >= len(self.src) 1370 1371 def _rch(self) -> str: 1372 pos, self.pos = self.pos, self.pos + 1 1373 return self.src[pos] 1374 1375 def _hex(self, ch: str) -> bool: 1376 if len(ch) == 1 and ch[0] == '0': 1377 return self._ch.lower() == 'x' 1378 elif len(ch) <= 1 or ch[1].lower() != 'x': 1379 return self._ch.isdigit() 1380 else: 1381 return self._ch in string.hexdigits 1382 1383 def _int(self, ch: str) -> Token: 1384 while not self._eof and self._hex(ch): 1385 ch += self._rch() 1386 else: 1387 if ch.lower().startswith('0x'): 1388 return Token.num(int(ch, 16)) 1389 elif ch[0] == '0': 1390 return Token.num(int(ch, 8)) 1391 else: 1392 return Token.num(int(ch)) 1393 1394 def _name(self, ch: str) -> Token: 1395 while not self._eof and (self._ch == '_' or self._ch.isalnum()): 1396 ch += self._rch() 1397 else: 1398 return Token.name(ch) 1399 1400 def _read(self, ch: str) -> Token: 1401 if ch.isdigit(): 1402 return self._int(ch) 1403 elif ch.isidentifier(): 1404 return self._name(ch) 1405 elif ch in ('*', '<', '>') and not self._eof and self._ch == ch: 1406 return Token.punc(self._rch() * 2) 1407 elif ch in ('+', '-', '*', '/', '%', '&', '|', '^', '~', '(', ')'): 1408 return Token.punc(ch) 1409 else: 1410 raise SyntaxError('invalid character: ' + repr(ch)) 1411 1412 def _peek(self) -> Optional[Token]: 1413 pos = self.pos 1414 ret = self._next() 1415 self.pos = pos 1416 return ret 1417 1418 def _next(self) -> Optional[Token]: 1419 while not self._eof and self._ch.isspace(): 1420 self.pos += 1 1421 else: 1422 return Token.end() if self._eof else self._read(self._rch()) 1423 1424 def _grab(self, tk: Token, getvalue: Callable[[str], int]) -> int: 1425 if tk.tag == TOKEN_NUM: 1426 return tk.val 1427 elif tk.tag == TOKEN_NAME: 1428 return getvalue(tk.val) 1429 else: 1430 raise SyntaxError('integer or identifier expected, got ' + repr(tk)) 1431 1432 __pred__ = [ 1433 {'<<', '>>'}, 1434 {'|'}, 1435 {'^'}, 1436 {'&'}, 1437 {'+', '-'}, 1438 {'*', '/', '%'}, 1439 {'**'}, 1440 ] 1441 1442 __binary__ = { 1443 '+' : lambda a, b: a + b, 1444 '-' : lambda a, b: a - b, 1445 '*' : lambda a, b: a * b, 1446 '/' : lambda a, b: a / b, 1447 '%' : lambda a, b: a % b, 1448 '&' : lambda a, b: a & b, 1449 '^' : lambda a, b: a ^ b, 1450 '|' : lambda a, b: a | b, 1451 '<<' : lambda a, b: a << b, 1452 '>>' : lambda a, b: a >> b, 1453 '**' : lambda a, b: a ** b, 1454 } 1455 1456 def _eval(self, op: str, v1: int, v2: int) -> int: 1457 return self.__binary__[op](v1, v2) 1458 1459 def _nest(self, nest: int, getvalue: Callable[[str], int]) -> int: 1460 ret = self._expr(0, nest + 1, getvalue) 1461 ntk = self._next() 1462 1463 # it must follows with a ')' operator 1464 if ntk.tag != TOKEN_PUNC or ntk.val != ')': 1465 raise SyntaxError('")" expected, got ' + repr(ntk)) 1466 else: 1467 return ret 1468 1469 def _unit(self, nest: int, getvalue: Callable[[str], int]) -> int: 1470 tk = self._next() 1471 tt, tv = tk.tag, tk.val 1472 1473 # check for unary operators 1474 if tt == TOKEN_NUM: 1475 return tv 1476 elif tt == TOKEN_NAME: 1477 return getvalue(tv) 1478 elif tt == TOKEN_PUNC and tv == '(': 1479 return self._nest(nest, getvalue) 1480 elif tt == TOKEN_PUNC and tv == '+': 1481 return self._unit(nest, getvalue) 1482 elif tt == TOKEN_PUNC and tv == '-': 1483 return -self._unit(nest, getvalue) 1484 elif tt == TOKEN_PUNC and tv == '~': 1485 return ~self._unit(nest, getvalue) 1486 else: 1487 raise SyntaxError('integer, unary operator or nested expression expected, got ' + repr(tk)) 1488 1489 def _term(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int: 1490 lv = self._expr(pred + 1, nest, getvalue) 1491 tk = self._peek() 1492 1493 # scan to the end 1494 while True: 1495 tt = tk.tag 1496 tv = tk.val 1497 1498 # encountered EOF 1499 if tt == TOKEN_END: 1500 return lv 1501 1502 # must be an operator here 1503 if tt != TOKEN_PUNC: 1504 raise SyntaxError('operator expected, got ' + repr(tk)) 1505 1506 # check for the operator precedence 1507 if tv not in self.__pred__[pred]: 1508 return lv 1509 1510 # apply the operator 1511 op = self._next().val 1512 rv = self._expr(pred + 1, nest, getvalue) 1513 lv = self._eval(op, lv, rv) 1514 tk = self._peek() 1515 1516 def _expr(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int: 1517 if pred >= len(self.__pred__): 1518 return self._unit(nest, getvalue) 1519 else: 1520 return self._term(pred, nest, getvalue) 1521 1522 def eval(self, getvalue: Callable[[str], int]) -> int: 1523 return self._expr(0, 0, getvalue) 1524 1525 1526 class Instr: 1527 ALIGN_WIDTH = 48 1528 len : int = NotImplemented 1529 instr : Union[str, Instruction] = NotImplemented 1530 1531 def size(self, pc: int) -> int: 1532 return self.len 1533 1534 def formatted(self, pc: int) -> str: 1535 raise NotImplementedError 1536 1537 @staticmethod 1538 def raw_formatted(bs: bytes, comm: str, pc: int) -> str: 1539 t = '\t' 1540 if bs: 1541 for b in bs: 1542 t +='0x%02x, ' % b 1543 # if len(bs)<Instr.ALIGN_WIDTH: 1544 # t += '\b' * (Instr.ALIGN_WIDTH - len(bs)) 1545 return '%s//%s%s' % (t, ('0x%08x ' % pc) if pc else ' ', comm) 1546 class RawInstr(Instr): 1547 bs: bytes 1548 def __init__(self, size: int, instr: str, bs: bytes): 1549 self.len = size 1550 self.instr = instr 1551 self.bs = bs 1552 1553 def formatted(self, _: int) -> str: 1554 return '\t' + self.instr 1555 1556 def raw_formatted(self, pc: int) -> str: 1557 return Instr.raw_formatted(self.bs, self.instr, pc) 1558 1559 class IntInstr(Instr): 1560 comm: str 1561 func: Callable[[], int] 1562 1563 def __init__(self, size: int, func: Callable[[], int], comments: str = ''): 1564 self.len = size 1565 self.func = func 1566 self.comm = comments 1567 1568 @property 1569 def instr(self) -> str: 1570 return Instruction.encode(self.func().to_bytes(self.len, 'little'), self.comm) 1571 1572 def formatted(self, _: int) -> str: 1573 return '\t' + self.instr 1574 1575 def raw_formatted(self, pc: int) -> str: 1576 return Instr.raw_formatted(self.func().to_bytes(self.len, 'little'), self.comm, pc) 1577 1578 class X86Instr(Instr): 1579 def __init__(self, instr: Instruction): 1580 self.len = instr.size 1581 self.instr = instr 1582 1583 def resize(self, size: int) -> int: 1584 self.len = size 1585 return size 1586 1587 def formatted(self, _: int) -> str: 1588 return '\t' + str(self.instr.encoded) 1589 1590 def raw_formatted(self, pc: int) -> str: 1591 return Instr.raw_formatted(self.instr._raw_normal_instr(), str(self.instr), pc) 1592 1593 class LabelInstr(Instr): 1594 def __init__(self, name: str): 1595 self.len = 0 1596 self.instr = name 1597 1598 def formatted(self, _: int) -> str: 1599 if self.instr.isidentifier(): 1600 return self.instr + ':' 1601 else: 1602 return '_LB_%08x: // %s' % (hash(self.instr) & 0xffffffff, self.instr) 1603 1604 def raw_formatted(self, pc: int) -> str: 1605 return Instr.raw_formatted(None, str(self.instr), pc) 1606 1607 class BranchInstr(Instr): 1608 def __init__(self, instr: Instruction): 1609 self.len = instr.size 1610 self.instr = instr 1611 1612 def formatted(self, _: int) -> str: 1613 return '\t' + self.instr.encoded 1614 1615 def raw_formatted(self, pc: int) -> str: 1616 return Instr.raw_formatted(self.instr._raw_branch_instr(), str(self.instr), pc) 1617 1618 class CommentInstr(Instr): 1619 def __init__(self, text: str): 1620 self.len = 0 1621 self.instr = '// ' + text 1622 1623 def formatted(self, _: int) -> str: 1624 return '\t' + self.instr 1625 1626 def raw_formatted(self, pc: int) -> str: 1627 return Instr.raw_formatted(None, str(self.instr), None) 1628 1629 class AlignmentInstr(Instr): 1630 bits: int 1631 fill: int 1632 1633 def __init__(self, bits: int, fill: int = 0): 1634 self.bits = bits 1635 self.fill = fill 1636 1637 def size(self, pc: int) -> int: 1638 mask = (1 << self.bits) - 1 1639 return (mask - (pc & mask) + 1) & mask 1640 1641 def formatted(self, pc: int) -> str: 1642 buf = bytes([self.fill]) * self.size(pc) 1643 return '\t' + Instruction.encode(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill)) 1644 1645 def raw_formatted(self, pc: int) -> str: 1646 buf = bytes([self.fill]) * self.size(pc) 1647 return Instr.raw_formatted(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill), pc) 1648 1649 REG_MAP = { 1650 'rax' : ('MOVQ' , 'AX'), 1651 'rdi' : ('MOVQ' , 'DI'), 1652 'rsi' : ('MOVQ' , 'SI'), 1653 'rdx' : ('MOVQ' , 'DX'), 1654 'rcx' : ('MOVQ' , 'CX'), 1655 'r8' : ('MOVQ' , 'R8'), 1656 'r9' : ('MOVQ' , 'R9'), 1657 'xmm0' : ('MOVSD' , 'X0'), 1658 'xmm1' : ('MOVSD' , 'X1'), 1659 'xmm2' : ('MOVSD' , 'X2'), 1660 'xmm3' : ('MOVSD' , 'X3'), 1661 'xmm4' : ('MOVSD' , 'X4'), 1662 'xmm5' : ('MOVSD' , 'X5'), 1663 'xmm6' : ('MOVSD' , 'X6'), 1664 'xmm7' : ('MOVSD' , 'X7'), 1665 } 1666 1667 class Counter: 1668 value: int = 0 1669 1670 @classmethod 1671 def next(cls) -> int: 1672 val, cls.value = cls.value, cls.value + 1 1673 return val 1674 1675 class BasicBlock: 1676 maxsp: int 1677 name: str 1678 weak: bool 1679 jmptab: bool 1680 func: bool 1681 body: List[Instr] 1682 prevs: List['BasicBlock'] 1683 next: Optional['BasicBlock'] 1684 jump: Optional['BasicBlock'] 1685 1686 def __init__(self, name: str, weak: bool = True, jmptab: bool = False, func: bool = False): 1687 self.maxsp = -1 1688 self.body = [] 1689 self.prevs = [] 1690 self.name = name 1691 self.weak = weak 1692 self.next = None 1693 self.jump = None 1694 self.jmptab = jmptab 1695 self.func = func 1696 1697 def __repr__(self): 1698 return '{BasicBlock %s}' % repr(self.name) 1699 1700 @property 1701 def last(self) -> Optional[Instr]: 1702 return next((v for v in reversed(self.body) if not isinstance(v, CommentInstr)), None) 1703 1704 def size_of(self, pc: int) -> int: 1705 return functools.reduce(lambda p, v: p + v.size(pc + p), self.body, 0) 1706 1707 def link_to(self, block: 'BasicBlock'): 1708 self.next = block 1709 block.prevs.append(self) 1710 1711 def jump_to(self, block: 'BasicBlock'): 1712 self.jump = block 1713 block.prevs.append(self) 1714 1715 @classmethod 1716 def annonymous(cls) -> 'BasicBlock': 1717 return cls('// bb.%d' % Counter.next(), weak = False) 1718 1719 CLANG_JUMPTABLE_LABLE = 'LJTI' 1720 1721 class CodeSection: 1722 dead : bool 1723 export : bool 1724 blocks : List[BasicBlock] 1725 labels : Dict[str, BasicBlock] 1726 jmptabs: Dict[str, List[BasicBlock]] 1727 funcs : Dict[str, Pcsp] 1728 1729 def __init__(self): 1730 self.dead = False 1731 self.labels = {} 1732 self.export = False 1733 self.blocks = [BasicBlock.annonymous()] 1734 self.jmptabs = {} 1735 self.funcs = {} 1736 1737 @classmethod 1738 def _dfs_jump_first(cls, bb: BasicBlock, visited: Dict[BasicBlock, bool], hook: Callable[[BasicBlock], bool]) -> bool: 1739 if bb not in visited or not visited[bb]: 1740 visited[bb] = True 1741 if bb.jump and not cls._dfs_jump_first(bb.jump, visited, hook): 1742 return False 1743 if bb.next and not cls._dfs_jump_first(bb.next, visited, hook): 1744 return False 1745 return hook(bb) 1746 else: 1747 return True 1748 1749 def get_jmptab(self, name: str) -> List[BasicBlock]: 1750 return self.jmptabs.setdefault(name, []) 1751 1752 def get_block(self, name: str) -> BasicBlock: 1753 for block in self.blocks: 1754 if block.name == name: 1755 return block 1756 1757 @property 1758 def block(self) -> BasicBlock: 1759 return self.blocks[-1] 1760 1761 @property 1762 def instrs(self) -> Iterable[Instr]: 1763 for block in self.blocks: 1764 yield from block.body 1765 1766 def _make(self, name: str, jmptab: bool = False, func: bool = False): 1767 if func: 1768 #NOTICE: if it is a function, always set func to be True 1769 if (old := self.labels.get(name)) and (old.func != func): 1770 old.func = True 1771 return self.labels.setdefault(name, BasicBlock(name, jmptab = jmptab, func = func)) 1772 1773 def _next(self, link: BasicBlock): 1774 if self.dead: 1775 self.dead = False 1776 else: 1777 self.block.link_to(link) 1778 1779 def _decl(self, name: str, block: BasicBlock): 1780 block.weak = False 1781 block.body.append(LabelInstr(name)) 1782 self._next(block) 1783 self.blocks.append(block) 1784 1785 def _kill(self, name: str): 1786 self.dead = True 1787 self.block.link_to(self._make(name)) 1788 1789 def _split(self, jmp: BasicBlock): 1790 self.jump = True 1791 link = BasicBlock.annonymous() 1792 self.labels[link.name] = link 1793 self.block.link_to(link) 1794 self.block.jump_to(jmp) 1795 self.blocks.append(link) 1796 1797 @staticmethod 1798 def _mk_align(v: int) -> int: 1799 if v & 7 == 0: 1800 return v 1801 else: 1802 print('* warning: SP is not aligned with 8 bytes.', file = sys.stderr) 1803 return (v + 7) & -8 1804 1805 @staticmethod 1806 def _is_spadj(ins: Instruction) -> bool: 1807 return len(ins.operands) == 2 and \ 1808 isinstance(ins.operands[0], Immediate) and \ 1809 isinstance(ins.operands[1], Register) and \ 1810 ins.operands[1].reg == 'rsp' 1811 1812 @staticmethod 1813 def _is_spmove(ins: Instruction, i: int) -> bool: 1814 return len(ins.operands) == 2 and \ 1815 isinstance(ins.operands[0], Register) and \ 1816 isinstance(ins.operands[1], Register) and \ 1817 ins.operands[i].reg == 'rsp' 1818 1819 @staticmethod 1820 def _is_rjump(ins: Optional[Instr]) -> bool: 1821 return isinstance(ins, X86Instr) and ins.instr.is_branch_label 1822 1823 def _find_label(self, name: str, adjs: Iterable[int], size: int = 0) -> int: 1824 for adj, block in zip(adjs, self.blocks): 1825 if block.name == name: 1826 return size 1827 else: 1828 size += block.size_of(size) + adj 1829 else: 1830 raise SyntaxError('unresolved reference to name: ' + name) 1831 1832 def _alloc_instr(self, instr: Instruction): 1833 if not instr.is_branch_label: 1834 self.block.body.append(X86Instr(instr)) 1835 else: 1836 self.block.body.append(BranchInstr(instr)) 1837 1838 # it seems to not be able to specify stack aligment inside the Go ASM so we 1839 # need to replace the aligned instructions with unaligned one if either of it's 1840 # operand is an RBP relative addressing memory operand 1841 1842 __instr_repl__ = { 1843 'movdqa' : 'movdqu', 1844 'movaps' : 'movups', 1845 'vmovdqa' : 'vmovdqu', 1846 'vmovaps' : 'vmovups', 1847 'vmovapd' : 'vmovupd', 1848 } 1849 1850 def _check_align(self, instr: Instruction) -> bool: 1851 if instr.mnemonic in self.__instr_repl__: 1852 # NOTICE: since we need use unaligned instruction, thus SP can be fixed according to PC 1853 for op in instr.operands: 1854 if isinstance(op, Memory): 1855 if op.base is not None and (op.base.reg == 'rbp' or op.base.reg == 'rsp'): 1856 instr.mnemonic = self.__instr_repl__[instr.mnemonic] 1857 return False 1858 elif instr.mnemonic == 'andq' and self._is_spadj(instr): 1859 # NOTICE: since we always use unaligned instruction above, we don't need align SP 1860 return True 1861 1862 def _check_split(self, instr: Instruction): 1863 if instr.is_return: 1864 self.dead = True 1865 1866 elif instr.is_jmpq: # jmpq 1867 # backtrace jump table from current block (BFS) 1868 prevs = [self.block] 1869 visited = set() 1870 while len(prevs) > 0: 1871 curb = prevs.pop() 1872 if curb in visited: 1873 continue 1874 else: 1875 visited.add(curb) 1876 1877 # backtrace instructions 1878 for ins in reversed(curb.body): 1879 if isinstance(ins, X86Instr) and ins.instr.jmptab: 1880 self._split(self._make(ins.instr.jmptab, jmptab = True)) 1881 return 1882 1883 if curb.prevs: 1884 prevs.extend(curb.prevs) 1885 1886 elif instr.is_branch_label: 1887 if instr.is_jmp: # jmp 1888 self._kill(instr.operands[0].name) 1889 1890 elif instr.is_invoke: # call 1891 fname = instr.operands[0].name 1892 self._split(self._make(fname, func = True)) 1893 1894 else: # jeq, ja, jae ... 1895 self._split(self._make(instr.operands[0].name)) 1896 1897 def _trace_block(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int: 1898 if (pcsp is not None): 1899 if bb.name in self.funcs: 1900 # already traced 1901 pcsp = None 1902 else: 1903 # continue tracing, update the pcsp 1904 # NOTICE: must mark pcsp at block entry because go only calculate delta value 1905 pcsp.pc = self.get(bb.name) 1906 if bb.func or pcsp.pc < pcsp.entry: 1907 # new func 1908 pcsp = Pcsp(pcsp.pc) 1909 self.funcs[bb.name] = pcsp 1910 1911 if bb.maxsp == -1: 1912 ret = self._trace_nocache(bb, pcsp) 1913 return ret 1914 elif bb.maxsp >= 0: 1915 return bb.maxsp 1916 else: 1917 return 0 1918 1919 def _trace_nocache(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int: 1920 bb.maxsp = -2 1921 1922 # ## FIXME: 1923 # if pcsp is None: 1924 # pcsp = Pcsp(0) 1925 1926 # make a fake object just for reducing redundant checking 1927 if pcsp: 1928 pc0, sp0 = pcsp.pc, pcsp.sp 1929 1930 maxsp, term = self._trace_instructions(bb, pcsp) 1931 1932 # this is a terminating block 1933 if term: 1934 return maxsp 1935 1936 # don't trace it's next block if it's an unconditional jump 1937 a, b = 0, 0 1938 if pcsp: 1939 pc, sp = pcsp.pc, pcsp.sp 1940 1941 if bb.jump: 1942 if bb.jump.jmptab: 1943 cases = self.get_jmptab(bb.jump.name) 1944 for case in cases: 1945 nsp = self._trace_block(case, pcsp) 1946 if pcsp: 1947 pcsp.pc, pcsp.sp = pc, sp 1948 if nsp > a: 1949 a = nsp 1950 else: 1951 a = self._trace_block(bb.jump, pcsp) 1952 if pcsp: 1953 pcsp.pc, pcsp.sp = pc, sp 1954 1955 if bb.next: 1956 b = self._trace_block(bb.next, pcsp) 1957 1958 if pcsp: 1959 pcsp.pc, pcsp.sp = pc0, sp0 1960 1961 # select the maximum stack depth 1962 bb.maxsp = maxsp + max(a, b) 1963 return bb.maxsp 1964 1965 def _trace_instructions(self, bb: BasicBlock, pcsp: Pcsp) -> Tuple[int, bool]: 1966 cursp = 0 1967 maxsp = 0 1968 close = False 1969 1970 # scan every instruction 1971 for ins in bb.body: 1972 diff = 0 1973 1974 if isinstance(ins, X86Instr): 1975 name = ins.instr.mnemonic 1976 args = ins.instr.operands 1977 1978 # check for instructions 1979 if name == 'retq': 1980 close = True 1981 elif name == 'popq': 1982 diff = -8 1983 elif name == 'pushq': 1984 diff = 8 1985 elif name == 'addq' and self._is_spadj(ins.instr): 1986 diff = -self._mk_align(args[0].val) 1987 elif name == 'subq' and self._is_spadj(ins.instr): 1988 diff = self._mk_align(args[0].val) 1989 1990 # FIXME: andq is usually used for aligment of memory address, we can't handle it correctly now 1991 # elif name == 'andq' and self._is_spadj(ins.instr): 1992 # diff = self._mk_align(max(-args[0].val - 8, 0)) 1993 1994 cursp += diff 1995 1996 #NOTICE: pcsp no need to update here 1997 if name == 'callq': 1998 cursp += 8 1999 2000 # update the max stack depth 2001 if cursp > maxsp: 2002 maxsp = cursp 2003 2004 # update pcsp 2005 if pcsp: 2006 pcsp.update(ins.size(pcsp.pc), diff) 2007 2008 # trace successful 2009 return maxsp, close 2010 2011 def get(self, key: str) -> Optional[int]: 2012 if key not in self.labels: 2013 raise SyntaxError('unresolved reference to name: ' + key) 2014 else: 2015 return self._find_label(key, itertools.repeat(0, len(self.blocks))) 2016 2017 def has(self, key: str) -> bool: 2018 return key in self.labels 2019 2020 def emit(self, buf: bytes, comments: str = ''): 2021 if not self.dead: 2022 self.block.body.append(RawInstr(len(buf), Instruction.encode(buf, comments or buf.hex()), buf)) 2023 2024 def lazy(self, size: int, func: Callable[[], int], comments: str = ''): 2025 if not self.dead: 2026 self.block.body.append(IntInstr(size, func, comments)) 2027 2028 def label(self, name: str): 2029 if name not in self.labels or self.labels[name].weak: 2030 self._decl(name, self._make(name)) 2031 else: 2032 raise SyntaxError('duplicated label: ' + name) 2033 2034 def instr(self, instr: Instruction): 2035 if not self.dead: 2036 if self._check_align(instr): 2037 return 2038 self._alloc_instr(instr) 2039 self._check_split(instr) 2040 2041 def stacksize(self, name: str) -> int: 2042 if name not in self.labels: 2043 raise SyntaxError('undefined function: ' + name) 2044 else: 2045 return self._trace_block(self.labels[name], None) 2046 2047 def pcsp(self, name: str, entry: int) -> int: 2048 if name not in self.labels: 2049 raise SyntaxError('undefined function: ' + name) 2050 else: 2051 pcsp = Pcsp(entry) 2052 self.labels[name].func = True 2053 return self._trace_block(self.labels[name], pcsp) 2054 2055 def debug(self, pos: int, inss: List[Instruction]): 2056 def inject(bb: BasicBlock) -> bool: 2057 if (not bb.func) and (bb.name not in self.funcs): 2058 return True 2059 nonlocal pos 2060 if pos >= len(bb.body): 2061 return 2062 for ins in inss: 2063 bb.body.insert(pos, ins) 2064 pos += 1 2065 visited = {} 2066 for _, bb in self.labels.items(): 2067 CodeSection._dfs_jump_first(bb, visited, inject) 2068 2069 STUB_NAME = '__native_entry__' 2070 STUB_SIZE = 67 2071 WITH_OFFS = os.getenv('ASM2ASM_DEBUG_OFFSET', '').lower() in ('1', 'yes', 'true') 2072 2073 class Assembler: 2074 out : List[str] 2075 subr : Dict[str, int] 2076 code : CodeSection 2077 vals : Dict[str, Union[str, int]] 2078 2079 def __init__(self): 2080 self.out = [] 2081 self.subr = {} 2082 self.vals = {} 2083 self.code = CodeSection() 2084 2085 def _get(self, v: str) -> int: 2086 if v not in self.vals: 2087 return self.code.get(v) 2088 elif isinstance(self.vals[v], int): 2089 return self.vals[v] 2090 else: 2091 ret = self.vals[v] = self._eval(self.vals[v]) 2092 return ret 2093 2094 def _eval(self, v: str) -> int: 2095 return Expression(v).eval(self._get) 2096 2097 def _emit(self, v: bytes, cmd: str): 2098 for i in range(0, len(v), 16): 2099 self.code.emit(v[i:i + 16], '%s %d, %s' % (cmd, len(v[i:i + 16]), repr(v[i:i + 16])[1:])) 2100 2101 def _limit(self, v: int, a: int, b: int) -> int: 2102 if not (a <= v <= b): 2103 raise SyntaxError('integer constant out of bound [%d, %d): %d' % (a, b, v)) 2104 else: 2105 return v 2106 2107 def _vfill(self, cmd: str, args: List[str]) -> Tuple[int, int]: 2108 if len(args) == 1: 2109 return self._limit(self._eval(args[0]), 1, 1 << 64), 0 2110 elif len(args) == 2: 2111 return self._limit(self._eval(args[0]), 1, 1 << 64), self._limit(self._eval(args[1]), 0, 255) 2112 else: 2113 raise SyntaxError(cmd + ' takes 1 ~ 2 arguments') 2114 2115 def _bytes(self, cmd: str, args: List[str], low: int, high: int, size: int): 2116 if len(args) != 1: 2117 raise SyntaxError(cmd + ' takes exact 1 argument') 2118 else: 2119 self.code.lazy(size, lambda: self._limit(self._eval(args[0]), low, high) & high, '%s %s' % (cmd, args[0])) 2120 2121 def _comment(self, msg: str): 2122 self.code.blocks[-1].body.append(CommentInstr(msg)) 2123 2124 def _cmd_nop(self, _: List[str]): 2125 pass 2126 2127 def _cmd_set(self, args: List[str]): 2128 if len(args) != 2: 2129 raise SyntaxError(".set takes exact 2 argument") 2130 elif not args[0].isidentifier(): 2131 raise SyntaxError(repr(args[0]) + " is not a valid identifier") 2132 else: 2133 key = args[0] 2134 val = args[1] 2135 self.vals[key] = val 2136 self._comment('.set ' + ', '.join(args)) 2137 # special case: clang-generated jump tables are always like '{block}_{table}' 2138 jt = val.find(CLANG_JUMPTABLE_LABLE) 2139 if jt > 0: 2140 tab = self.code.get_jmptab(val[jt:]) 2141 tab.append(self.code.get_block(val[:jt-1])) 2142 2143 def _cmd_byte(self, args: List[str]): 2144 self._bytes('.byte', args, -0x80, 0xff, 1) 2145 2146 def _cmd_word(self, args: List[str]): 2147 self._bytes('.word', args, -0x8000, 0xffff, 2) 2148 2149 def _cmd_long(self, args: List[str]): 2150 self._bytes('.long', args, -0x80000000, 0xffffffff, 4) 2151 2152 def _cmd_quad(self, args: List[str]): 2153 self._bytes('.quad', args, -0x8000000000000000, 0xffffffffffffffff, 8) 2154 2155 def _cmd_ascii(self, args: List[str]): 2156 if len(args) != 1: 2157 raise SyntaxError('.ascii takes exact 1 argument') 2158 else: 2159 self._emit(args[0].encode('latin-1'), '.ascii') 2160 2161 def _cmd_asciz(self, args: List[str]): 2162 if len(args) != 1: 2163 raise SyntaxError('.asciz takes exact 1 argument') 2164 else: 2165 self._emit(args[0].encode('latin-1') + b'\0', '.asciz') 2166 2167 def _cmd_space(self, args: List[str]): 2168 nb, fv = self._vfill('.space', args) 2169 self._emit(bytes([fv] * nb), '.space') 2170 2171 def _cmd_p2align(self, args: List[str]): 2172 if len(args) == 1: 2173 self.code.block.body.append(AlignmentInstr(self._eval(args[0]))) 2174 elif len(args) == 2: 2175 self.code.block.body.append(AlignmentInstr(self._eval(args[0]), self._eval(args[1]))) 2176 else: 2177 raise SyntaxError('.p2align takes 1 ~ 2 arguments') 2178 2179 @functools.cached_property 2180 def _commands(self) -> dict: 2181 return { 2182 '.set' : self._cmd_set, 2183 '.int' : self._cmd_long, 2184 '.long' : self._cmd_long, 2185 '.byte' : self._cmd_byte, 2186 '.quad' : self._cmd_quad, 2187 '.word' : self._cmd_word, 2188 '.hword' : self._cmd_word, 2189 '.short' : self._cmd_word, 2190 '.ascii' : self._cmd_ascii, 2191 '.asciz' : self._cmd_asciz, 2192 '.space' : self._cmd_space, 2193 '.globl' : self._cmd_nop, 2194 '.p2align' : self._cmd_p2align, 2195 '.section' : self._cmd_nop, 2196 '.data_region' : self._cmd_nop, 2197 '.build_version' : self._cmd_nop, 2198 '.end_data_region' : self._cmd_nop, 2199 '.subsections_via_symbols' : self._cmd_nop, 2200 } 2201 2202 @staticmethod 2203 def _is_rip_relative(op: Operand) -> bool: 2204 return isinstance(op, Memory) and \ 2205 op.base is not None and \ 2206 op.base.reg == 'rip' and \ 2207 op.index is None and \ 2208 isinstance(op.disp, Reference) 2209 2210 @staticmethod 2211 def _remove_comments(line: str, *, st: str = 'normal') -> str: 2212 for i, ch in enumerate(line): 2213 if st == 'normal' and ch == '/' : st = 'slcomm' 2214 elif st == 'normal' and ch == '\"' : st = 'string' 2215 elif st == 'normal' and ch in ('#', ';') : return line[:i] 2216 elif st == 'slcomm' and ch == '/' : return line[:i - 1] 2217 elif st == 'slcomm' : st = 'normal' 2218 elif st == 'string' and ch == '\"' : st = 'normal' 2219 elif st == 'string' and ch == '\\' : st = 'escape' 2220 elif st == 'escape' : st = 'string' 2221 else: 2222 return line 2223 2224 def _parse(self, src: List[str]): 2225 for line in src: 2226 line = self._remove_comments(line) 2227 line = line.strip() 2228 2229 # skip empty lines 2230 if not line: 2231 continue 2232 2233 # labels, resolve the offset 2234 if line[-1] == ':': 2235 self.code.label(line[:-1]) 2236 continue 2237 2238 # instructions 2239 if line[0] != '.': 2240 self.code.instr(Instruction.parse(line)) 2241 continue 2242 2243 # parse the command 2244 cmd = Command.parse(line) 2245 func = self._commands.get(cmd.cmd) 2246 2247 # handle the command 2248 if func is not None: 2249 func(cmd.args) 2250 else: 2251 raise SyntaxError('invalid assembly command: ' + cmd.cmd) 2252 2253 def _reloc(self, rip: int = 0): 2254 for block in self.code.blocks: 2255 for instr in block.body: 2256 rip += self._reloc_one(instr, rip) 2257 2258 def _reloc_one(self, instr: Instr, rip: int) -> int: 2259 if not isinstance(instr, (X86Instr, BranchInstr)): 2260 return instr.size(rip) 2261 elif instr.instr.is_branch_label and isinstance(instr.instr.operands[0], Label): 2262 return self._reloc_branch(instr.instr, rip) 2263 else: 2264 return instr.resize(self._reloc_normal(instr.instr, rip)) 2265 2266 def _reloc_branch(self, instr: Instruction, rip: int) -> int: 2267 instr.operands[0].resolve(self.code.get(instr.operands[0].name) - rip - instr.size) 2268 return instr.size 2269 2270 def _reloc_normal(self, instr: Instruction, rip: int) -> int: 2271 msg = [] 2272 ops = instr.operands 2273 2274 # relocate RIP relative operands 2275 for i, op in enumerate(ops): 2276 if self._is_rip_relative(op): 2277 if self.code.has(str(op.disp.ref)): 2278 self._reloc_static(op.disp, msg, rip + instr.size) 2279 else: 2280 raise SyntaxError('unresolved reference to name ' + str(op.disp.ref)) 2281 2282 # attach comments if any 2283 instr.comments = ', '.join(msg) or instr.comments 2284 return instr.size 2285 2286 def _reloc_static(self, ref: Reference, msg: List[str], rip: int): 2287 msg.append('%s+%d(%%rip)' % (ref.ref, ref.disp)) 2288 ref.resolve(self.code.get(str(ref.ref)) - rip) 2289 2290 def _declare(self, protos: PrototypeMap): 2291 if OUTPUT_RAW: 2292 self._declare_body_raw() 2293 else: 2294 self._declare_body() 2295 self._declare_functions(protos) 2296 2297 def _declare_body(self): 2298 self.out.append('TEXT ·%s(SB), NOSPLIT, $0' % STUB_NAME) 2299 self.out.append('\tNO_LOCAL_POINTERS') 2300 self._reloc() 2301 2302 # instruction buffer 2303 pc = 0 2304 ins = self.code.instrs 2305 2306 # dump every instruction 2307 for v in ins: 2308 self.out.append(('// +%d\n' % pc if WITH_OFFS else '') + v.formatted(pc)) 2309 pc += v.size(pc) 2310 2311 def _declare_body_raw(self): 2312 self._reloc() 2313 2314 # instruction buffer 2315 pc = 0 2316 ins = self.code.instrs 2317 2318 # dump every instruction 2319 for v in ins: 2320 self.out.append(v.raw_formatted(pc)) 2321 pc += v.size(pc) 2322 2323 def _declare_function(self, name: str, proto: Prototype): 2324 offs = 0 2325 subr = name[1:] 2326 addr = self.code.get(subr) 2327 self.subr[subr] = addr 2328 size = self.code.pcsp(subr, addr) 2329 2330 if OUTPUT_RAW: 2331 return 2332 2333 # function header and stack checking 2334 self.out.append('') 2335 self.out.append('TEXT ·%s(SB), NOSPLIT | NOFRAME, $0 - %d' % (name, proto.argspace)) 2336 self.out.append('\tNO_LOCAL_POINTERS') 2337 2338 # add stack check if needed 2339 if size != 0: 2340 self.out.append('') 2341 self.out.append('_entry:') 2342 self.out.append('\tMOVQ (TLS), R14') 2343 self.out.append('\tLEAQ -%d(SP), R12' % size) 2344 self.out.append('\tCMPQ R12, 16(R14)') 2345 self.out.append('\tJBE _stack_grow') 2346 2347 # function name 2348 self.out.append('') 2349 self.out.append('%s:' % subr) 2350 2351 # intialize all the arguments 2352 for arg in proto.args: 2353 offs += arg.size 2354 op, reg = REG_MAP[arg.creg.reg] 2355 self.out.append('\t%s %s+%d(FP), %s' % (op, arg.name, offs - arg.size, reg)) 2356 2357 # the function starts at zero 2358 if addr == 0 and proto.retv is None: 2359 self.out.append('\tJMP ·%s(SB) // %s' % (STUB_NAME, subr)) 2360 2361 # Go ASM completely ignores the offset of the JMP instruction, 2362 # so we need to use indirect jumps instead for tail-call elimination 2363 elif proto.retv is None: 2364 self.out.append('\tLEAQ ·%s+%d(SB), AX // %s' % (STUB_NAME, addr, subr)) 2365 self.out.append('\tJMP AX') 2366 2367 # normal functions, call the real function, and return the result 2368 else: 2369 self.out.append('\tCALL ·%s+%d(SB) // %s' % (STUB_NAME, addr, subr)) 2370 self.out.append('\t%s, %s+%d(FP)' % (' '.join(REG_MAP[proto.retv.creg.reg]), proto.retv.name, offs)) 2371 self.out.append('\tRET') 2372 2373 # add stack growing if needed 2374 if size != 0: 2375 self.out.append('') 2376 self.out.append('_stack_grow:') 2377 self.out.append('\tCALL runtime·morestack_noctxt<>(SB)') 2378 self.out.append('\tJMP _entry') 2379 2380 def _declare_functions(self, protos: PrototypeMap): 2381 for name, proto in sorted(protos.items()): 2382 if name[0] == '_': 2383 self._declare_function(name, proto) 2384 else: 2385 raise SyntaxError('function prototype must have a "_" prefix: ' + repr(name)) 2386 2387 def parse(self, src: List[str], proto: PrototypeMap): 2388 self.code.instr(Instruction('leaq', [Memory(Register('rip'), Immediate(-7), None), Register('rax')])) 2389 self.code.instr(Instruction('movq', [Register('rax'), Memory(Register('rsp'), Immediate(8), None)])) 2390 self.code.instr(Instruction('retq', [])) 2391 self._parse(src) 2392 # print("DEBUG...") 2393 # self.code.debug(0, [ 2394 # X86Instr(Instruction('int3', [])) 2395 # # X86Instr(Instruction('xorq', [Register('rax'), Register('rax')])), 2396 # # X86Instr(Instruction('movq', [Memory(Register('rax'), Immediate(0), None), Register('rax')])) 2397 # ]) 2398 self._declare(proto) 2399 2400 GOOS = { 2401 'aix', 2402 'android', 2403 'darwin', 2404 'dragonfly', 2405 'freebsd', 2406 'hurd', 2407 'illumos', 2408 'js', 2409 'linux', 2410 'nacl', 2411 'netbsd', 2412 'openbsd', 2413 'plan9', 2414 'solaris', 2415 'windows', 2416 'zos', 2417 } 2418 2419 GOARCH = { 2420 '386', 2421 'amd64', 2422 'amd64p32', 2423 'arm', 2424 'armbe', 2425 'arm64', 2426 'arm64be', 2427 'ppc64', 2428 'ppc64le', 2429 'mips', 2430 'mipsle', 2431 'mips64', 2432 'mips64le', 2433 'mips64p32', 2434 'mips64p32le', 2435 'ppc', 2436 'riscv', 2437 'riscv64', 2438 's390', 2439 's390x', 2440 'sparc', 2441 'sparc64', 2442 'wasm', 2443 } 2444 2445 def make_subr_filename(name: str) -> str: 2446 name = os.path.basename(name) 2447 base = os.path.splitext(name)[0].rsplit('_', 2) 2448 2449 # construct the new name 2450 if base[-1] in GOOS: 2451 return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1]) 2452 elif base[-1] not in GOARCH: 2453 return '%s_subr.go' % '_'.join(base) 2454 elif len(base) > 2 and base[-2] in GOOS: 2455 return '%s_subr_%s_%s.go' % ('_'.join(base[:-2]), base[-2], base[-1]) 2456 else: 2457 return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1]) 2458 2459 def main(): 2460 src = [] 2461 asm = Assembler() 2462 2463 2464 # check for arguments 2465 if len(sys.argv) < 3: 2466 print('* usage: %s [-r|-d] <output-file> <clang-asm> ...' % sys.argv[0], file = sys.stderr) 2467 sys.exit(1) 2468 2469 # check if optional flag is enabled 2470 global OUTPUT_RAW 2471 OUTPUT_RAW = False 2472 if len(sys.argv) >= 4: 2473 i = 0 2474 while i<len(sys.argv): 2475 flag = sys.argv[i] 2476 if flag == '-r': 2477 OUTPUT_RAW = True 2478 for j in range(i, len(sys.argv)-1): 2479 sys.argv[j] = sys.argv[j + 1] 2480 sys.argv.pop() 2481 continue 2482 i += 1 2483 2484 # parse the prototype 2485 with open(os.path.splitext(sys.argv[1])[0] + '.go', 'r', newline = None) as fp: 2486 pkg, proto = PrototypeMap.parse(fp.read()) 2487 2488 # read all the sources, and combine them together 2489 for fn in sys.argv[2:]: 2490 with open(fn, 'r', newline = None) as fp: 2491 src.extend(fp.read().splitlines()) 2492 2493 # convert the original sources 2494 if OUTPUT_RAW: 2495 asm.out.append('// +build amd64') 2496 asm.out.append('// Code generated by asm2asm, DO NOT EDIT.') 2497 asm.out.append('') 2498 asm.out.append('package %s' % pkg) 2499 asm.out.append('') 2500 ## native text 2501 asm.out.append('var Text%s = []byte{' % STUB_NAME) 2502 else: 2503 asm.out.append('// +build !noasm !appengine') 2504 asm.out.append('// Code generated by asm2asm, DO NOT EDIT.') 2505 asm.out.append('') 2506 asm.out.append('#include "go_asm.h"') 2507 asm.out.append('#include "funcdata.h"') 2508 asm.out.append('#include "textflag.h"') 2509 asm.out.append('') 2510 2511 asm.parse(src, proto) 2512 2513 if OUTPUT_RAW: 2514 asrc = os.path.splitext(sys.argv[1])[0] 2515 asrc = asrc[:asrc.rfind('_')] + '_text_amd64.go' 2516 else: 2517 asrc = os.path.splitext(sys.argv[1])[0] + '.s' 2518 2519 # save the converted result 2520 with open(asrc, 'w') as fp: 2521 for line in asm.out: 2522 print(line, file = fp) 2523 if OUTPUT_RAW: 2524 print('}', file = fp) 2525 2526 # calculate the subroutine stub file name 2527 subr = make_subr_filename(sys.argv[1]) 2528 subr = os.path.join(os.path.dirname(sys.argv[1]), subr) 2529 2530 # save the compiled code stub 2531 with open(subr, 'w') as fp: 2532 print('// +build !noasm !appengine', file = fp) 2533 print('// Code generated by asm2asm, DO NOT EDIT.', file = fp) 2534 print(file = fp) 2535 print('package %s' % pkg, file = fp) 2536 2537 # also save the actual function addresses if any 2538 if not asm.subr: 2539 return 2540 2541 if OUTPUT_RAW: 2542 print(file = fp) 2543 print('import (\n\t`github.com/goshafaq/sonic/loader`\n)', file = fp) 2544 2545 # dump every entry for all functions 2546 print(file = fp) 2547 print('const (', file = fp) 2548 for name in asm.code.funcs.keys(): 2549 addr = asm.code.get(name) 2550 if addr is not None: 2551 print(f' _entry_{name} = %d' % addr, file = fp) 2552 print(')', file = fp) 2553 2554 # dump max stack depth for all functions 2555 print(file = fp) 2556 print('const (', file = fp) 2557 for name in asm.code.funcs.keys(): 2558 print(' _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp) 2559 print(')', file = fp) 2560 2561 # dump every text size for all functions 2562 print(file = fp) 2563 print('const (', file = fp) 2564 for name, pcsp in asm.code.funcs.items(): 2565 if pcsp is not None: 2566 # print(f'before {name} optimize {pcsp}') 2567 pcsp.optimize() 2568 # print(f'after {name} optimize {pcsp}') 2569 print(f' _size_{name} = %d' % (pcsp.maxpc - pcsp.entry), file = fp) 2570 print(')', file = fp) 2571 2572 # dump every pcsp for all functions 2573 print(file = fp) 2574 print('var (', file = fp) 2575 for name, pcsp in asm.code.funcs.items(): 2576 if pcsp is not None: 2577 print(f' _pcsp_{name} = %s' % pcsp, file = fp) 2578 print(')', file = fp) 2579 2580 # insert native entry info 2581 print(file = fp) 2582 print('var Funcs = []loader.CFunc{', file = fp) 2583 print(' {"%s", 0, %d, 0, nil},' % (STUB_NAME, STUB_SIZE), file = fp) 2584 # dump every native function info for all functions 2585 for name in asm.code.funcs.keys(): 2586 print(' {"%s", _entry_%s, _size_%s, _stack_%s, _pcsp_%s},' % (name, name, name, name, name), file = fp) 2587 print('}', file = fp) 2588 2589 else: 2590 # native entry for entry function 2591 print(file = fp) 2592 print('//go:nosplit', file = fp) 2593 print('//go:noescape', file = fp) 2594 print('//goland:noinspection ALL', file = fp) 2595 print('func %s() uintptr' % STUB_NAME, file = fp) 2596 2597 # dump exported function entry for exported functions 2598 print(file = fp) 2599 print('var (', file = fp) 2600 mlen = max(len(s) for s in asm.subr) 2601 for name, entry in asm.subr.items(): 2602 print(' _subr_%s uintptr = %s() + %d' % (name.ljust(mlen, ' '), STUB_NAME, entry), file = fp) 2603 print(')', file = fp) 2604 2605 # dump max stack depth for exported functions 2606 print(file = fp) 2607 print('const (', file = fp) 2608 for name in asm.subr.keys(): 2609 print(' _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp) 2610 print(')', file = fp) 2611 2612 # assign subroutine offsets to '_' to mute the "unused" warnings 2613 print(file = fp) 2614 print('var (', file = fp) 2615 for name in asm.subr: 2616 print(' _ = _subr_%s' % name, file = fp) 2617 print(')', file = fp) 2618 2619 # dump every constant 2620 print(file = fp) 2621 print('const (', file = fp) 2622 for name in asm.subr: 2623 print(' _ = _stack_%s' % name, file = fp) 2624 else: 2625 print(')', file = fp) 2626 2627 if __name__ == '__main__': 2628 main()