#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Origin: github.com/bytedance/sonic@v1.11.7-0.20240517092252-d2edb31b167b/tools/asm2arm/arm.py

import os
import sys
import string
import argparse
import itertools
import functools

from typing import Any
from typing import Dict
from typing import List
from typing import Type
from typing import Tuple
from typing import Union
from typing import Callable
from typing import Iterable
from typing import Optional

import mcasm

class InstrStreamer(mcasm.Streamer):
    """Streamer that records the arguments of the last 'emit_instruction' event."""

    data: bytes
    instr: Optional[mcasm.mc.Instruction]
    fixups: List[mcasm.mc.Fixup]

    def __init__(self):
        self.data = b''
        self.instr = None
        self.fixups = []
        super().__init__()

    def unhandled_event(self, name: str, base_impl, *args, **kwargs):
        # capture the instruction, its encoded bytes and its fixups, then
        # always defer to the base implementation
        if name == 'emit_instruction':
            self.instr = args[1]
            self.data = args[2]
            self.fixups = args[3]
        return super().unhandled_event(name, base_impl, *args, **kwargs)

### Instruction Parser (GAS Syntax) ###

class Token:
    """A lexer token: a TOKEN_* tag plus its integer or string value."""

    tag: int
    val: Union[int, str]

    def __init__(self, tag: int, val: Union[int, str]):
        self.val = val
        self.tag = tag

    @classmethod
    def end(cls):
        return cls(TOKEN_END, '')

    @classmethod
    def reg(cls, reg: str):
        return cls(TOKEN_REG, reg)

    @classmethod
    def imm(cls, imm: int):
        return cls(TOKEN_IMM, imm)

    @classmethod
    def num(cls, num: int):
        return cls(TOKEN_NUM, num)

    @classmethod
    def name(cls, name: str):
        return cls(TOKEN_NAME, name)

    @classmethod
    def punc(cls, punc: str):
        return cls(TOKEN_PUNC, punc)

    def __repr__(self):
        if self.tag == TOKEN_END:
            return '<END>'
        elif self.tag == TOKEN_REG:
            return '<REG %s>' % self.val
        elif self.tag == TOKEN_IMM:
            return '<IMM %d>' % self.val
        elif self.tag == TOKEN_NUM:
            return '<NUM %d>' % self.val
        elif self.tag == TOKEN_NAME:
            return '<NAME %s>' % repr(self.val)
        elif self.tag == TOKEN_PUNC:
            return '<PUNC %s>' % repr(self.val)
        else:
            return '<UNK:%d %r>' % (self.tag, self.val)

class Label:
    """A named label whose offset is unknown until resolve() is called."""

    name: str
    offs: Optional[int]

    def __init__(self, name: str):
        self.name = name
        self.offs = None

    def __str__(self):
        return self.name

    def __repr__(self):
        if self.offs is None:
            return '{LABEL %s (unresolved)}' % self.name
        else:
            return '{LABEL %s (offset: %d)}' % (self.name, self.offs)

    def resolve(self, offs: int):
        self.offs = offs

class Index:
    """Index register with a scale factor (scale must be >= 1 to be printable)."""

    base  : 'Register'
    scale : int

    def __init__(self, base: 'Register', scale: int = 1):
        self.base = base
        self.scale = scale

    def __str__(self):
        if self.scale == 1:
            return ',%s' % self.base
        elif self.scale >= 2:
            return ',%s,%d' % (self.base, self.scale)
        else:
            raise RuntimeError('invalid parser state: invalid scale')

    def __repr__(self):
        if self.scale == 1:
            return repr(self.base)
        elif self.scale >= 2:
            return '%d * %r' % (self.scale, self.base)
        else:
            raise RuntimeError('invalid parser state: invalid scale')

class Memory:
    """Memory operand made of an optional base, displacement and index.

    Raises SyntaxError on construction when both base and index are missing.
    """

    base  : Optional['Register']
    disp  : Optional['Displacement']
    index : Optional[Index]

    def __init__(self, base: Optional['Register'], disp: Optional['Displacement'], index: Optional[Index]):
        self.base = base
        self.disp = disp
        self.index = index
        self._validate()

    def __str__(self):
        return '%s(%s%s)' % (
            '' if self.disp is None else self.disp,
            '' if self.base is None else self.base,
            '' if self.index is None else self.index
        )

    def __repr__(self):
        return '{MEM %r%s%s}' % (
            '' if self.base is None else self.base,
            '' if self.index is None else ' + ' + repr(self.index),
            '' if self.disp is None else ' + ' + repr(self.disp)
        )

    def _validate(self):
        if self.base is None and self.index is None:
            raise SyntaxError('either base or index must be specified')

class Register:
    """Register operand; the name is stored lower-cased."""

    reg: str

    def __init__(self, reg: str):
        self.reg = reg.lower()

    def __str__(self):
        return '%' + self.reg

    def __repr__(self):
        return '{REG %s}' % self.reg

class Immediate:
    """Integer immediate operand."""

    val: int
    ref: str

    def __init__(self, val: int):
        self.ref = ''
        self.val = val

    def __str__(self):
        return '$%d' % self.val

    def __repr__(self):
        return '{IMM bin:%s, oct:%s, dec:%d, hex:%s}' % (
            bin(self.val)[2:],
            oct(self.val)[2:],
            self.val,
            hex(self.val)[2:],
        )

class Reference:
    """Symbolic reference plus displacement, resolved to an offset later."""

    ref: str
    disp: int
    off: Optional[int]

    def __init__(self, ref: str, disp: int = 0):
        self.ref = ref
        self.disp = disp
        self.off = None

    def __str__(self):
        if self.off is None:
            return self.ref
        else:
            return '$' + str(self.off)

    def __repr__(self):
        if self.off is None:
            return '{REF %s + %d (unresolved)}' % (self.ref, self.disp)
        else:
            return '{REF %s + %d (offset: %d)}' % (self.ref, self.disp, self.off)

    @property
    def offset(self) -> int:
        """Resolved offset; raises SyntaxError while still unresolved."""
        if self.off is None:
            raise SyntaxError('unresolved reference to ' + repr(self.ref))
        else:
            return self.off

    def resolve(self, off: int):
        # the stored offset includes the reference's own displacement
        self.off = self.disp + off

Operand = Union[
    Label,
    Memory,
    Register,
    Immediate,
    Reference,
]

Displacement = Union[
    Immediate,
    Reference,
]

TOKEN_END  = 0
TOKEN_REG  = 1
TOKEN_IMM  = 2
TOKEN_NUM  = 3
TOKEN_NAME = 4
TOKEN_PUNC = 5

ARM_ADRP_IMM_BIT_SIZE = 21
ARM_ADR_WIDTH = 1024 * 1024

class Instruction:
    """One assembly line assembled through mcasm, with label fixup support.

    The constructor parses `line` immediately. Encoded bytes live in `data`;
    pending relocations live in `fixups` and are patched once a label offset
    has been supplied through set_label_offset().
    """

    comments: str
    mnemonic: str
    asm_code: str
    data: bytes
    instr: Optional[mcasm.mc.Instruction]
    fixups: List[mcasm.mc.Fixup]
    offs_: Optional[int]
    ADRP_label: Optional[str]
    text_label: Optional[str]
    back_label: Optional[str]
    ADR_instr: Optional[str]
    adrp_asm: Optional[str]
    is_adrp: bool

    def __init__(self, line: str, adrp_count=0):
        self.comments = ''
        self.offs_ = None
        self.is_adrp = False
        # optional ADRP bookkeeping, filled in only for 'adrp' lines
        self.back_label = None
        self.ADR_instr = None
        self.adrp_asm = None
        self.asm = mcasm.Assembler('aarch64-apple-macos11')

        self.parse(line, adrp_count)

    def __str__(self):
        return self.asm_code

    def __repr__(self):
        return '{INSTR %s}' % ( self.asm_code )

    @property
    def jmptab(self) -> Optional[str]:
        """Label name when this is an ADRP-derived reference to a jump table, else None."""
        if self.is_adrp and self.label_name.find(CLANG_JUMPTABLE_LABLE) != -1:
            return self.label_name

    @property
    def size(self) -> int:
        return len(self.data)

    @functools.cached_property
    def label_name(self) -> Optional[str]:
        """Symbol name of the single fixup, or None when no relocation is needed.

        Raises RuntimeError when the instruction carries more than one fixup.
        """
        if len(self.fixups) > 1:
            raise RuntimeError('has more than 1 fixup: ' + self.asm_code)
        if self.need_reloc:
            if self.mnemonic == 'adr':
                return self.fixups[0].value.sub_expr.symbol.name
            else:
                return self.fixups[0].value.symbol.name
        else:
            return None

    @functools.cached_property
    def is_branch(self) -> bool:
        return self.instr.desc.is_branch or self.is_invoke

    @functools.cached_property
    def is_return(self) -> bool:
        return self.instr.desc.is_return

    @functools.cached_property
    def is_jmpq(self) -> bool:
        # return self.mnemonic == 'br'
        return False

    @functools.cached_property
    def is_jmp(self) -> bool:
        return self.mnemonic == 'b'

    @functools.cached_property
    def is_invoke(self) -> bool:
        return self.instr.desc.is_call

    @property
    def is_branch_label(self) -> bool:
        return self.is_branch and (len(self.fixups) != 0)

    @property
    def need_reloc(self) -> bool:
        return (len(self.fixups) != 0)

    def set_label_offset(self, off):
        # arm64
        self.offs_ = off + 4

    # def _encode_normal_instr(self) -> str:
    #     return self.encode(self.data, self.asm_code)

    @functools.cached_property
    def encoded(self) -> str:
        """Encoded text of the instruction (computed once, then cached)."""
        if self.need_reloc:
            return self._encode_reloc_instr()
        else:
            return self._encode_normal_instr()

    def _check_offs_is_valid(self, bit_size: int):
        """Raise RuntimeError when |offs_| exceeds 1 << bit_size."""
        if abs(self.offs_) > (1 << bit_size):
            raise RuntimeError('offset is too larger, [assembly]: %s, [offset]: %d, [valid off size]: %d'
                % (self.asm_code, self.offs_, self.fixups[0].kind_info.bit_size))

    def _encode_adr(self):
        """Patch the immlo/immhi fields of `data` with `offs_`."""
        # NOTE(review): the immediate is added into `data` in place, so calling
        # this twice on the same instruction would add it twice -- confirm the
        # callers only patch once.
        buf = int.from_bytes(self.data, byteorder='little')
        bit_size = ARM_ADRP_IMM_BIT_SIZE

        self._check_offs_is_valid(bit_size)

        # adrp op: | op | immlo | 1 0 0 0 0 | immhi | Rd |
        #          |31  |30   29|28       24|23    5|4  0|
        imm_lo = (self.offs_ << 29) & 0x60000000
        imm_hi = (self.offs_ << 3) & 0x00FFFFE0
        encode_data = (buf + imm_lo + imm_hi).to_bytes(4, byteorder='little')
        self.data = encode_data
        # return self.encode(encode_data, '%s $%s(%%rip)' % (str(self), self.offs_))

    def _encode_rel32(self):
        """Patch the fixup's immediate field of `data` with `offs_`."""
        if self.mnemonic == 'adrp' or self.mnemonic == 'adr':
            return self._encode_adr()
        buf = int.from_bytes(self.data, byteorder='little')

        imm = self.offs_
        imm_size = self.fixups[0].kind_info.bit_size
        imm_offset = self.fixups[0].kind_info.bit_offset
        if self.fixups[0].kind_info.is_pc_rel == 1:
            # except adr and adrp, other PC-releative instructions need times 4
            imm = imm >> 2
            # immediate bit size has 1-bit for sign
            self._check_offs_is_valid(imm_size - 1 + 2)
        else:
            self._check_offs_is_valid(imm_size)

        imm = imm << imm_offset
        mask = (0x1 << (imm_size + imm_offset)) - 1
        buf = buf | (imm & mask)
        buf = buf.to_bytes(4, byteorder='little')
        self.data = buf
        # return self.encode(buf, '%s $%s(%%rip)' % (str(self), self.offs_))

    def _encode_page(self):
        if self.mnemonic != 'adrp':
            raise RuntimeError("not adrp instruction: %s" % self.asm_code)
        # NOTE(review): `offs_` is overwritten with the shifted value, so a
        # second call would shift it again -- confirm this runs only once.
        self.offs_ = self.offs_ >> 12
        return self._encode_rel32()

    def _encode_pageoff(self):
        self.offs_ = 0
        return self._encode_rel32()

    def _fixup_rel32(self):
        """Apply the pending fixup; raises RuntimeError if the label is unresolved."""
        if self.offs_ is None:
            raise RuntimeError('unresolved label %s' % self.label_name)

        if self.mnemonic == 'adr':
            self._encode_adr()
        elif self.fixups[0].value.variant_kind == mcasm.mc.SymbolRefExpr.VariantKind.PAGEOFF:
            self._encode_pageoff()
        elif self.fixups[0].value.variant_kind == mcasm.mc.SymbolRefExpr.VariantKind.PAGE:
            self._encode_page()
        else:
            self._encode_rel32()

    def _encode_reloc_instr(self) -> str:
        self._fixup_rel32()
        return self.encode(self.data, '%s $%s(%%rip)' % (str(self), self.offs_))

    def _encode_normal_instr(self) -> str:
        return self.encode(self.data, str(self))

    def _raw_instr(self) -> bytes:
        if self.need_reloc:
            self._fixup_rel32()
        return self.data

    def _fixup_adrp(self, line: str, adrp_count: int) -> str:
        """Record ADRP bookkeeping for `line` and return a 'b <label>' replacement line."""
        reg = line.split()[1].split(',')[0]
        self.text_label = line.split()[2].split('@')[0]
        self.ADRP_label = self.text_label + '_' + reg + '_' + str(adrp_count)
        self.back_label = '_back_adrp_' + str(adrp_count)
        self.ADR_instr = 'adr ' + reg + ', ' + self.text_label
        self.adrp_asm = line
        self.is_adrp = True
        line = 'b ' + self.ADRP_label
        self.asm_code = line + ' // ' + self.adrp_asm

        return line

    def _parse_by_mcasm(self, line: str):
        """Assemble `line` and store the instruction, bytes, fixups and mnemonic.

        Raises RuntimeError when the assembler emits no instruction.
        """
        streamer = InstrStreamer()
        # self.asm.assemble(streamer, line, MCPU="", features_str="")
        self.asm.assemble(streamer, line)
        if streamer.instr is None:
            raise RuntimeError('cannot parse assembly: %s' % line)
        self.instr = streamer.instr

        # instead of short jump instruction
        self.data = streamer.data

        self.fixups = streamer.fixups
        self.mnemonic = line.split()[0]

    def convert_to_adr(self):
        """Rewrite the stored ADRP line as ADR, keeping the original as a trailing comment."""
        self.is_adrp = True
        adr_asm = self.adrp_asm.replace('adrp', 'adr')
        self.asm_code = adr_asm + ' // ' + self.adrp_asm
        # self._parse_by_mcasm(adr_asm)
        return self.asm_code

    def parse(self, line: str, adrp_count: int):
        """Parse one assembly line, converting 'adrp' into 'adr' first."""
        # machine code
        menmonic = line.split()[0]

        self.ADRP_label = None
        self.text_label = None
        # turn adrp to jmp
        if (menmonic == 'adrp'):
            # convert_to_adr() reads `adrp_asm`, so the original line has to be
            # recorded before the call (it was never assigned on this path)
            self.adrp_asm = line
            line = self.convert_to_adr()
        else:
            self.asm_code = line

        self._parse_by_mcasm(line)

    @staticmethod
    def encode(buf: bytes, comments: str = '') -> str:
        """Render `buf` as '; '-joined 'WORD $0x...' items, plus an optional comment.

        Raises RuntimeError when len(buf) is not a multiple of 4.
        """
        i = 0
        r = []
        n = len(buf)

        # @debug
        # while i < n - 3:
        #    r.append('%08x' % int.from_bytes(buf[i:i + 4], 'little'))
        #    i += 4
        # return '\n\t'.join(r)

        if (n % 4 != 0):
            raise RuntimeError("Unkown instruction which not encoding 4 bytes: %s " % comments, buf)

        while i < n - 3:
            r.append('WORD $0x%08x' % int.from_bytes(buf[i:i + 4], 'little'))
            i += 4

        # join them together, and attach the comment if any
        if not comments:
            return '; '.join(r)
        else:
            return '%s // %s' % ('; '.join(r), comments)

Reg = Optional[Register]
Disp = Optional[Displacement]

### Prototype Parser ###

ARGS_ORDER_C = [
    Register('x0'),
    Register('x1'),
    Register('x2'),
    Register('x3'),
    Register('x4'),
    Register('x5'),
    Register('x6'),
    Register('x7'),
]

ARGS_ORDER_GO = [
    Register('R0'),
    Register('R1'),
    Register('R2'),
    Register('R3'),
    Register('R4'),
    Register('R5'),
    Register('R6'),
    Register('R7'),
]

FPARGS_ORDER = [
    Register('D0'),
    Register('D1'),
    Register('D2'),
    Register('D3'),
    Register('D4'),
    Register('D5'),
    Register('D6'),
    Register('D7'),
]

class Parameter:
    """A named argument or return value with its size and C / Go registers."""

    name : str
    size : int
    creg : Register
    goreg: Register

    def __init__(self, name: str, size: int, reg: Register, goreg: Register):
        self.creg = reg
        # keep the Go-side register separate from the C-side one
        self.goreg = goreg
        self.name = name
        self.size = size

    def __repr__(self):
        return '<ARG %s(%d): %s>' % (self.name, self.size, self.creg)

class Pcsp:
    """Accumulates (pc, sp) records relative to an entry address."""

    entry: int
    maxpc: int
    out  : List[Tuple[int, int]]
    pc   : int
    sp   : int

    def __init__(self, entry: int):
        self.out = []
        self.maxpc = entry
        self.entry = entry
        self.pc = entry
        self.sp = 0

    def __str__(self) -> str:
        rows = ''.join(' {%d, %d},\n' % (pc, sp) for pc, sp in self.out)
        return '[][2]uint32{\n' + rows + ' }'

    def optimize(self):
        """Append the final record, sort by pc and merge runs with an unchanged sp."""
        # push the last record
        self.out.append((self.pc - self.entry, self.sp))
        # sort by pc
        self.out.sort(key=lambda x: x[0])
        # NOTICE: first pair {1, 0} to be compitable with golang
        tmp = [(1, 0)]
        lpc, lsp = 0, -1
        for pc, sp in self.out:
            # sp changed, push new record
            if pc != lpc and sp != lsp:
                tmp.append((pc, sp))
            # sp unchanged, replace with the higher pc
            if pc != lpc and sp == lsp:
                if len(tmp) > 0:
                    tmp.pop(-1)
                tmp.append((pc, sp))

            lpc, lsp = pc, sp
        self.out = tmp

    def update(self, dpc: int, dsp: int):
        """Record the current (pc, sp), then advance both by the given deltas."""
        self.out.append((self.pc - self.entry, self.sp))
        self.pc += dpc
        self.sp += dsp
        if self.pc > self.maxpc:
            self.maxpc = self.pc

class Prototype:
    """Function prototype: an optional return value and a list of arguments."""

    args: List[Parameter]
    retv: Optional[Parameter]

    def __init__(self, retv: Optional[Parameter], args: List[Parameter]):
        self.retv = retv
        self.args = args

    def __repr__(self):
        if self.retv is None:
            return '<PROTO (%s)>' % repr(self.args)
        else:
            return '<PROTO (%r) -> %r>' % (self.args, self.retv)

    @property
    def argspace(self) -> int:
        """Total size of all arguments plus the return value, if any."""
        return sum(
            [v.size for v in self.args],
            (0 if self.retv is None else self.retv.size)
        )

    @property
    def inputspace(self) -> int:
        """Total size of the arguments only."""
        return sum([v.size for v in self.args])

class PrototypeMap(Dict[str, Prototype]):
    """Map of function name to Prototype, built by a text-based scan of Go source."""

    @staticmethod
    def _dv(c: str) -> int:
        # nesting delta contributed by one character
        if c == '(':
            return 1
        elif c == ')':
            return -1
        else:
            return 0

    @staticmethod
    def _tk(s: str, p: str) -> bool:
        # `s` starts with keyword `p` followed by whitespace or end of string
        return s.startswith(p) and (s == p or s[len(p)].isspace())

    @classmethod
    def _punc(cls, s: str) -> bool:
        # NOTE(review): `__puncs_` is name-mangled and no such attribute is
        # defined in the visible part of this class -- confirm before calling.
        return s in cls.__puncs_

    @staticmethod
    def _err(msg: str) -> SyntaxError:
        return SyntaxError(
            msg + ', ' +
            'the parser integrated in this tool is just a text-based parser, ' +
            'so please keep the companion .go file as simple as possible and do not use defined types'
        )

    @staticmethod
    def _align(nb: int) -> int:
        # round up to the next multiple of 8
        return (((nb - 1) >> 3) + 1) << 3

    @classmethod
    def _retv(cls, ret: str) -> Tuple[str, int, Register, Register]:
        name, size, xmm = cls._args(ret)
        reg = Register('d0') if xmm else Register('x0')
        return name, size, reg, reg

    @classmethod
    def _args(cls, arg: str, sv: str = '') -> Tuple[str, int, bool]:
        """Split '<name> <type>' into (name, aligned size, is-floating-point)."""
        while True:
            if not arg:
                raise SyntaxError('missing type for parameter: ' + sv)
            elif arg[0] != '_' and not arg[0].isalnum():
                return (sv,) + cls._size(arg.strip())
            elif not sv and arg[0].isdigit():
                raise SyntaxError('invalid character: ' + repr(arg[0]))
            else:
                sv += arg[0]
                arg = arg[1:]

    @classmethod
    def _size(cls, name: str) -> Tuple[int, bool]:
        """Return (aligned size, is-floating-point) for a recognized type name."""
        if name[0] == '*':
            return cls._align(8), False
        elif name in ('int8', 'uint8', 'byte', 'bool'):
            return cls._align(1), False
        elif name in ('int16', 'uint16'):
            return cls._align(2), False
        elif name == 'float32':
            return cls._align(4), True
        elif name in ('int32', 'uint32', 'rune'):
            return cls._align(4), False
        elif name == 'float64':
            return cls._align(8), True
        elif name in ('int64', 'uint64', 'uintptr', 'int', 'Pointer', 'unsafe.Pointer'):
            return cls._align(8), False
        else:
            raise cls._err('unrecognized type "%s"' % name)

    @classmethod
    def _func(cls, src: List[str], idx: int, depth: int = 0) -> Tuple[str, int]:
        """Join lines from `idx` until parentheses balance; return (text, next index)."""
        for i in range(idx, len(src)):
            for x in map(cls._dv, src[i]):
                if depth + x >= 0:
                    depth += x
                else:
                    raise cls._err('encountered ")" more than "(" on line %d' % (i + 1))
            else:
                if depth == 0:
                    return ' '.join(src[idx:i + 1]), i + 1
        else:
            raise cls._err('unexpected EOF when parsing function signatures')

    @classmethod
    def parse(cls, src: str) -> Tuple[str, 'PrototypeMap']:
        """Scan Go source text and return (package name, prototype map)."""
        idx = 0
        pkg = ''
        ret = PrototypeMap()
        buf = src.splitlines()

        # scan through all the lines
        while idx < len(buf):
            line = buf[idx]
            line = line.strip()

            # skip empty lines
            if not line:
                idx += 1
                continue

            # check for package name
            if cls._tk(line, 'package'):
                idx, pkg = idx + 1, line[7:].strip().split()[0]
                continue

            if OUTPUT_RAW:

                # extract funcname like "[var ]{funcname} = func(..."
                end = line.find('func(')
                if end == -1:
                    idx += 1
                    continue
                name = line[:end].strip()
                if name.startswith('var '):
                    name = name[4:].strip()

                # function names must be identifiers
                if not name.isidentifier():
                    raise cls._err('invalid function prototype: ' + name)

                # register a empty prototype
                ret[name] = Prototype(None, [])
                idx += 1

            else:

                # only cares about those functions that does not have bodies
                if line[-1] == '{' or not cls._tk(line, 'func'):
                    idx += 1
                    continue

                # prevent type-aliasing primitive types into other names
                # NOTE(review): only lines starting with 'func' reach this
                # point, so this check looks unreachable -- confirm intent.
                if cls._tk(line, 'type'):
                    raise cls._err('please do not declare any type with in the companion .go file')

                # find the next function declaration
                decl, pos = cls._func(buf, idx)
                func, idx = decl[4:].strip(), pos

                # find the beginning '('
                nd = 1
                pos = func.find('(')

                # must have a '('
                if pos == -1:
                    raise cls._err('invalid function prototype: ' + decl)

                # extract the name and signature
                args = ''
                name = func[:pos].strip()
                func = func[pos + 1:].strip()

                # skip the method declaration
                if not name:
                    continue

                # function names must be identifiers
                if not name.isidentifier():
                    raise cls._err('invalid function prototype: ' + decl)

                # extract the argument list
                while nd and func:
                    nch = func[0]
                    func = func[1:]

                    # adjust the nesting level
                    nd += cls._dv(nch)
                    args += nch

                # check for EOF
                if not nd:
                    func = func.strip()
                else:
                    raise cls._err('unexpected EOF when parsing function prototype: ' + decl)

                # check for multiple returns
                if ',' in func:
                    raise cls._err('can only return a single value (detected by looking for "," within the return list)')

                # check for return signature
                if not func:
                    retv = None
                elif func[0] == '(' and func[-1] == ')':
                    retv = Parameter(*cls._retv(func[1:-1]))
                else:
                    raise SyntaxError('badly formatted return argument (please use parenthesis and proper arguments naming): ' + func)

                # extract the argument list
                if not args[:-1]:
                    args, alens, axmm = [], [], []
                else:
                    args, alens, axmm = list(zip(*[cls._args(v.strip()) for v in args[:-1].split(',')]))

                # check for the result
                cregs = []
                goregs = []
                idxs = [0, 0]

                # split the integer & floating point registers
                for xmm in axmm:
                    key = 0 if xmm else 1
                    seq = FPARGS_ORDER if xmm else ARGS_ORDER_C
                    goseq = FPARGS_ORDER if xmm else ARGS_ORDER_GO

                    # check the argument count
                    if idxs[key] >= len(seq):
                        raise cls._err("too many arguments, consider pack some into a pointer")

                    # add the register
                    cregs.append(seq[idxs[key]])
                    goregs.append(goseq[idxs[key]])
                    idxs[key] += 1

                # register the prototype
                ret[name] = Prototype(retv, [
                    Parameter(arg, size, creg, goreg)
                    for arg, size, creg, goreg in zip(args, alens, cregs, goregs)
                ])

        # all done
        return pkg, ret

### Assembly Source Parser ###

ESC_IDLE = 0    # escape parser is idleing
ESC_ISTR = 1    # currently inside a string
ESC_BKSL = 2    # encountered backslash, prepare for escape sequences
ESC_HEX0 = 3    # expect the first hexadecimal character of a "\x" escape
ESC_HEX1 = 4    # expect the second hexadecimal character of a "\x" escape
ESC_OCT1 = 5    # expect the second octal character of a "\000" escape
ESC_OCT2 = 6    # expect the third octal character of a "\000" escape

class Command:
    """An assembler directive: the command word plus its parsed arguments."""

    cmd  : str
    args : List[Union[str, bytes]]

    def __init__(self, cmd: str, args: List[Union[str, bytes]]):
        self.cmd = cmd
        self.args = args

    def __repr__(self):
        return '<CMD %s %s>' % (self.cmd, ', '.join(map(repr, self.args)))

    @classmethod
    def parse(cls, src: str) -> 'Command':
        """Split `src` into a command and comma-separated arguments.

        Double-quoted arguments are unquoted and their escape sequences
        decoded. Raises SyntaxError on an invalid escape sequence.
        """
        val = src.split(None, 1)
        cmd = val[0]

        # no parameters
        if len(val) == 1:
            return cls(cmd, [])

        # extract the argument string
        idx = 0
        esc = 0
        pos = None
        args = []
        vstr = val[1]

        # scan through the whole string
        while idx < len(vstr):
            nch = vstr[idx]
            idx += 1

            # mark the start of the argument
            if pos is None:
                pos = idx - 1

            # encountered the delimiter outside of a string
            if nch == ',' and esc == ESC_IDLE:
                pos, p = None, pos
                args.append(vstr[p:idx - 1].strip())

            # start of a string
            elif nch == '"' and esc == ESC_IDLE:
                esc = ESC_ISTR

            # end of string
            elif nch == '"' and esc == ESC_ISTR:
                esc = ESC_IDLE
                pos, p = None, pos
                args.append(vstr[p:idx].strip()[1:-1].encode('utf-8').decode('unicode_escape'))

            # escape characters
            elif nch == '\\' and esc == ESC_ISTR:
                esc = ESC_BKSL

            # hexadecimal escape characters (3 chars)
            elif esc == ESC_BKSL and nch == 'x':
                esc = ESC_HEX0

            # octal escape characters (3 chars)
            elif esc == ESC_BKSL and nch in string.octdigits:
                esc = ESC_OCT1

            # generic escape characters (single char)
            elif esc == ESC_BKSL and nch in ('a', 'b', 'f', 'r', 'n', 't', 'v', '"', '\\'):
                esc = ESC_ISTR

            # invalid escape sequence
            elif esc == ESC_BKSL:
                raise SyntaxError('invalid escape character: ' + repr(nch))

            # normal characters, simply advance to the next character
            elif esc in (ESC_IDLE, ESC_ISTR):
                pass

            # hexadecimal escape characters
            elif esc in (ESC_HEX0, ESC_HEX1) and nch.lower() in string.hexdigits:
                esc = ESC_HEX1 if esc == ESC_HEX0 else ESC_ISTR

            # invalid hexadecimal character
            elif esc in (ESC_HEX0, ESC_HEX1):
                raise SyntaxError('invalid hexdecimal character: ' + repr(nch))

            # octal escape characters
            elif esc in (ESC_OCT1, ESC_OCT2) and nch.lower() in string.octdigits:
                esc = ESC_OCT2 if esc == ESC_OCT1 else ESC_ISTR

            # at most 3 octal digits
            elif esc in (ESC_OCT1, ESC_OCT2):
                esc = ESC_ISTR

            # illegal state, should not happen
            else:
                raise RuntimeError('illegal state: %d' % esc)

        # check for the last argument
        if pos is None:
            return cls(cmd, args)

        # add the last argument and build the command
        args.append(vstr[pos:].strip())
        return cls(cmd, args)

def _trunc_div(a: int, b: int) -> int:
    """Integer division truncating toward zero; raises ZeroDivisionError when b == 0."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q

class Expression:
    """Evaluator for integer expressions with names resolved through a callback."""

    pos: int
    src: str

    def __init__(self, src: str):
        self.pos = 0
        self.src = src

    @property
    def _ch(self) -> str:
        return self.src[self.pos]

    @property
    def _eof(self) -> bool:
        return self.pos >= len(self.src)

    def _rch(self) -> str:
        # read the current character and advance
        pos, self.pos = self.pos, self.pos + 1
        return self.src[pos]

    def _hex(self, ch: str) -> bool:
        # decide whether the next character extends the numeric literal `ch`
        if len(ch) == 1 and ch[0] == '0':
            return self._ch.lower() == 'x'
        elif len(ch) <= 1 or ch[1].lower() != 'x':
            return self._ch.isdigit()
        else:
            return self._ch in string.hexdigits

    def _int(self, ch: str) -> Token:
        while not self._eof and self._hex(ch):
            ch += self._rch()
        else:
            if ch.lower().startswith('0x'):
                return Token.num(int(ch, 16))
            elif ch[0] == '0':
                return Token.num(int(ch, 8))
            else:
                return Token.num(int(ch))

    def _name(self, ch: str) -> Token:
        while not self._eof and (self._ch == '_' or self._ch.isalnum()):
            ch += self._rch()
        else:
            return Token.name(ch)

    def _read(self, ch: str) -> Token:
        if ch.isdigit():
            return self._int(ch)
        elif ch.isidentifier():
            return self._name(ch)
        elif ch in ('*', '<', '>') and not self._eof and self._ch == ch:
            # doubled operators: '**', '<<', '>>'
            return Token.punc(self._rch() * 2)
        elif ch in ('+', '-', '*', '/', '%', '&', '|', '^', '~', '(', ')'):
            return Token.punc(ch)
        else:
            raise SyntaxError('invalid character: ' + repr(ch))

    def _peek(self) -> Optional[Token]:
        # read the next token without consuming it
        pos = self.pos
        ret = self._next()
        self.pos = pos
        return ret

    def _next(self) -> Optional[Token]:
        while not self._eof and self._ch.isspace():
            self.pos += 1
        else:
            return Token.end() if self._eof else self._read(self._rch())

    def _grab(self, tk: Token, getvalue: Callable[[str], int]) -> int:
        if tk.tag == TOKEN_NUM:
            return tk.val
        elif tk.tag == TOKEN_NAME:
            return getvalue(tk.val)
        else:
            raise SyntaxError('integer or identifier expected, got ' + repr(tk))

    # operator groups, from lowest to highest precedence
    __pred__ = [
        {'<<', '>>'},
        {'|'},
        {'^'},
        {'&'},
        {'+', '-'},
        {'*', '/', '%'},
        {'**'},
    ]

    __binary__ = {
        '+'  : lambda a, b: a + b,
        '-'  : lambda a, b: a - b,
        '*'  : lambda a, b: a * b,
        # integer division, so the evaluator never produces a float
        '/'  : lambda a, b: _trunc_div(a, b),
        '%'  : lambda a, b: a % b,
        '&'  : lambda a, b: a & b,
        '^'  : lambda a, b: a ^ b,
        '|'  : lambda a, b: a | b,
        '<<' : lambda a, b: a << b,
        '>>' : lambda a, b: a >> b,
        '**' : lambda a, b: a ** b,
    }

    def _eval(self, op: str, v1: int, v2: int) -> int:
        return self.__binary__[op](v1, v2)

    def _nest(self, nest: int, getvalue: Callable[[str], int]) -> int:
        ret = self._expr(0, nest + 1, getvalue)
        ntk = self._next()

        # it must follows with a ')' operator
        if ntk.tag != TOKEN_PUNC or ntk.val != ')':
            raise SyntaxError('")" expected, got ' + repr(ntk))
        else:
            return ret

    def _unit(self, nest: int, getvalue: Callable[[str], int]) -> int:
        tk = self._next()
        tt, tv = tk.tag, tk.val

        # check for unary operators
        if tt == TOKEN_NUM:
            return tv
        elif tt == TOKEN_NAME:
            return getvalue(tv)
        elif tt == TOKEN_PUNC and tv == '(':
            return self._nest(nest, getvalue)
        elif tt == TOKEN_PUNC and tv == '+':
            return self._unit(nest, getvalue)
        elif tt == TOKEN_PUNC and tv == '-':
            return -self._unit(nest, getvalue)
        elif tt == TOKEN_PUNC and tv == '~':
            return ~self._unit(nest, getvalue)
        else:
            raise SyntaxError('integer, unary operator or nested expression expected, got ' + repr(tk))

    def _term(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
        lv = self._expr(pred + 1, nest, getvalue)
        tk = self._peek()

        # scan to the end
        while True:
            tt = tk.tag
            tv = tk.val

            # encountered EOF
            if tt == TOKEN_END:
                return lv

            # must be an operator here
            if tt != TOKEN_PUNC:
                raise SyntaxError('operator expected, got ' + repr(tk))

            # check for the operator precedence
            if tv not in self.__pred__[pred]:
                return lv

            # apply the operator
            op = self._next().val
            rv = self._expr(pred + 1, nest, getvalue)
            lv = self._eval(op, lv, rv)
            tk = self._peek()

    def _expr(self, pred: int, nest: int, getvalue: Callable[[str], int]) -> int:
        if pred >= len(self.__pred__):
            return self._unit(nest, getvalue)
        else:
            return self._term(pred, nest, getvalue)

    def eval(self, getvalue: Callable[[str], int]) -> int:
        """Evaluate the expression; `getvalue` maps an identifier to its integer value."""
        return self._expr(0, 0, getvalue)


class Instr:
    """Base class of the items stored in a basic block body."""

    ALIGN_WIDTH = 48
    len   : int                     = NotImplemented
    instr : Union[str, Instruction] = NotImplemented

    def size(self, _: int) -> int:
        return self.len

    def formatted(self, pc: int) -> str:
        raise NotImplementedError

    @staticmethod
    def raw_formatted(bs: bytes, comm: str, pc: int) -> str:
        """Render `bs` as '0x??, ' items followed by a '//' comment with pc and text."""
        t = '\t'
        if bs:
            t += ''.join('0x%02x, ' % b for b in bs)
        # if len(bs)<Instr.ALIGN_WIDTH:
        #     t += '\b' * (Instr.ALIGN_WIDTH - len(bs))
        return '%s//%s%s' % (t, ('0x%08x ' % pc) if pc else ' ', comm)

class RawInstr(Instr):
    """Pre-rendered instruction text together with its raw bytes."""

    bs: bytes
    def __init__(self, size: int, instr: str, bs: bytes):
        self.len = size
        self.instr = instr
        self.bs = bs

    def formatted(self, _: int) -> str:
        return '\t' + self.instr

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(self.bs, self.instr, pc)

class IntInstr(Instr):
    """Integer of `size` bytes whose value comes from calling `func`."""

    comm: str
    func: Callable[[], int]

    def __init__(self, size: int, func: Callable[[], int], comments: str = ''):
        self.len = size
        self.func = func
        self.comm = comments

    @property
    def raw_bytes(self):
        return self.func().to_bytes(self.len, 'little')

    @property
    def instr(self) -> str:
        return Instruction.encode(self.func().to_bytes(self.len, 'little'), self.comm)

    def formatted(self, _: int) -> str:
        return '\t' + self.instr

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(self.func().to_bytes(self.len, 'little'), self.comm, pc)

class X86Instr(Instr):
    """Wrapper around an assembled Instruction."""

    def __init__(self, instr: Instruction):
        self.len = instr.size
        self.instr = instr

    def resize(self, size: int) -> int:
        self.len = size
        return size

    def formatted(self, _: int) -> str:
        return '\t' + self.instr.encoded

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(self.instr._raw_instr(), str(self.instr), pc)

class LabelInstr(Instr):
    """Zero-sized label definition."""

    def __init__(self, name: str):
        self.len = 0
        self.instr = name

    def formatted(self, _: int) -> str:
        if self.instr.isidentifier():
            return self.instr + ':'
        else:
            # NOTE(review): hash() of a str varies between interpreter runs
            # unless PYTHONHASHSEED is fixed -- confirm the generated label
            # name does not need to be reproducible.
            return '_LB_%08x: // %s' % (hash(self.instr) & 0xffffffff, self.instr)

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(None, str(self.instr), pc)

class BranchInstr(Instr):
    """Wrapper around an assembled branch Instruction."""

    def __init__(self, instr: Instruction):
        self.len = instr.size
        self.instr = instr

    def formatted(self, _: int) -> str:
        return '\t' + self.instr.encoded

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(self.instr._raw_instr(), str(self.instr), pc)

class CommentInstr(Instr):
    """Zero-sized comment line."""

    def __init__(self, text: str):
        self.len = 0
        self.instr = '// ' + text

    def formatted(self, _: int) -> str:
        return '\t' + self.instr

    def raw_formatted(self, pc: int) -> str:
        return Instr.raw_formatted(None, str(self.instr), None)

class AlignmentInstr(Instr):
    """Padding up to a 2**bits boundary, filled with `fill` bytes."""

    bits: int
    fill: int

    def __init__(self, bits: int, fill: int = 0):
        self.bits = bits
        self.fill = fill

    def size(self, pc: int) -> int:
        # number of bytes needed to reach the next 2**bits boundary from `pc`
        mask = (1 << self.bits) - 1
        return (mask - (pc & mask) + 1) & mask

    def formatted(self, pc: int) -> str:
        buf = bytes([self.fill]) * self.size(pc)
        return '\t' + Instruction.encode(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill))

    def raw_formatted(self, pc: int) -> str:
        buf = bytes([self.fill]) * self.size(pc)
        return Instr.raw_formatted(buf, '.p2align %d, 0x%02x' % (self.bits, self.fill), pc)

REG_MAP = {
    'x0' : ('MOVD' , 'R0'),
    'x1' : ('MOVD' , 'R1'),
    'x2' : ('MOVD' , 'R2'),
    'x3' : ('MOVD' , 'R3'),
    'x4' : ('MOVD' , 'R4'),
    'x5' : ('MOVD' , 'R5'),
    'x6' : ('MOVD' , 'R6'),
    'x7' : ('MOVD' , 'R7'),
    'd0' : ('FMOVD' , 'F0'),
    'd1' : ('FMOVD' , 'F1'),
    'd2' : ('FMOVD' , 'F2'),
    'd3' : ('FMOVD' , 'F3'),
    'd4' : ('FMOVD' , 'F4'),
    'd5' : ('FMOVD' , 'F5'),
    'd6' : ('FMOVD' , 'F6'),
    'd7' : ('FMOVD' , 'F7'),
}

class Counter:
    """Process-wide monotonically increasing counter."""

    value: int = 0

    @classmethod
    def next(cls) -> int:
        val, cls.value = cls.value, cls.value + 1
        return val

class BasicBlock:
    maxsp: int
    name: str
    weak: bool
    jmptab: bool
    func: bool
    body: List[Instr]
    prevs: List['BasicBlock']
    next: Optional['BasicBlock']
    jump: Optional['BasicBlock']

    def __init__(self, name: str, weak: bool = True, jmptab: bool = False, func: bool = False):
        self.maxsp = -1
        self.body = []
        self.prevs = []
        self.name = name
        self.weak = weak
        self.next = None
        self.jump = None
        self.jmptab = jmptab
        self.func = func

    def __repr__(self):
        return '{BasicBlock %s}' % repr(self.name)

    @property
    def last(self) -> Optional[Instr]:
        # last body item that is not a comment, or None
        return next((v for v in reversed(self.body) if not isinstance(v, CommentInstr)), None)

    def if_all_IntInstr_then_2_RawInstr(self):
        is_table = False
        instr_size = 0
        for instr in self.body:
            if isinstance(instr, IntInstr):
                if not is_table:
                    instr_size = instr.len
                    is_table = True
                if instr_size != instr.len:
                    instr_size = 0
                continue
            if isinstance(instr, AlignmentInstr):
                continue
            if isinstance(instr, LabelInstr):
                continue
            # others
            return

        if not is_table:
            return

        # .long or .quad
        if instr_size == 8 or instr_size == 4:
            return

        # All instrs are IntInstr, golang asm only suuport WORD and DWORD for arm.
We need 1333 # combine them as 4-bytes RawInstr and align block 1334 nb = [] # new body 1335 raw_buf = []; 1336 comment = '' 1337 1338 # first element is LabelInstr 1339 for i in range(1, len(self.body)): 1340 if isinstance(self.body[i], AlignmentInstr): 1341 if i != len(self.body) -1: 1342 raise RuntimeError("Not support p2algin in : %s" % self.name) 1343 continue 1344 1345 raw_buf += self.body[i].raw_bytes 1346 comment += '// ' + self.body[i].comm + '\n' 1347 1348 align_size = len(raw_buf) % 4 1349 if align_size != 0: 1350 raw_buf += int(0).to_bytes(4 - align_size, 'little') 1351 1352 if isinstance(self.body[0], LabelInstr): 1353 nb.append(self.body[0]) 1354 1355 for i in range(0, len(raw_buf), 4): 1356 buf = raw_buf[i: i + 4] 1357 nb.append(RawInstr(len(buf), Instruction.encode(buf), buf)) 1358 1359 nb.append(CommentInstr(comment)) 1360 1361 if isinstance(self.body[-1:-1], AlignmentInstr): 1362 nb.append(self.body[-1:-1]) 1363 self.body = nb 1364 1365 def size_of(self, pc: int) -> int: 1366 return functools.reduce(lambda p, v: p + v.size(pc + p), self.body, 0) 1367 1368 def link_to(self, block: 'BasicBlock'): 1369 self.next = block 1370 block.prevs.append(self) 1371 1372 def jump_to(self, block: 'BasicBlock'): 1373 self.jump = block 1374 block.prevs.append(self) 1375 1376 @classmethod 1377 def annonymous(cls) -> 'BasicBlock': 1378 return cls('// bb.%d' % Counter.next(), weak = False) 1379 1380 CLANG_JUMPTABLE_LABLE = 'LJTI' 1381 1382 class CodeSection: 1383 dead : bool 1384 export : bool 1385 blocks : List[BasicBlock] 1386 labels : Dict[str, BasicBlock] 1387 jmptabs: Dict[str, List[BasicBlock]] 1388 funcs : Dict[str, Pcsp] 1389 bsmap_ : Dict[str, int] 1390 1391 def __init__(self): 1392 self.dead = False 1393 self.labels = {} 1394 self.export = False 1395 self.blocks = [BasicBlock.annonymous()] 1396 self.jmptabs = {} 1397 self.funcs = {} 1398 self.bsmap_ = {} 1399 1400 @classmethod 1401 def _dfs_jump_first(cls, bb: BasicBlock, visited: Dict[BasicBlock, bool], 
hook: Callable[[BasicBlock], bool]) -> bool: 1402 if bb not in visited or not visited[bb]: 1403 visited[bb] = True 1404 if bb.jump and not cls._dfs_jump_first(bb.jump, visited, hook): 1405 return False 1406 if bb.next and not cls._dfs_jump_first(bb.next, visited, hook): 1407 return False 1408 return hook(bb) 1409 else: 1410 return True 1411 1412 def get_jmptab(self, name: str) -> List[BasicBlock]: 1413 return self.jmptabs.setdefault(name, []) 1414 1415 def get_block(self, name: str) -> BasicBlock: 1416 for block in self.blocks: 1417 if block.name == name: 1418 return block 1419 1420 @property 1421 def block(self) -> BasicBlock: 1422 return self.blocks[-1] 1423 1424 @property 1425 def instrs(self) -> Iterable[Instr]: 1426 for block in self.blocks: 1427 yield from block.body 1428 1429 def _make(self, name: str, jmptab: bool = False, func: bool = False): 1430 if func: 1431 #NOTICE: if it is a function, always set func to be True 1432 if (old := self.labels.get(name)) and (old.func != func): 1433 old.func = True 1434 return self.labels.setdefault(name, BasicBlock(name, jmptab = jmptab, func = func)) 1435 1436 def _next(self, link: BasicBlock): 1437 if self.dead: 1438 self.dead = False 1439 else: 1440 self.block.link_to(link) 1441 1442 def _decl(self, name: str, block: BasicBlock): 1443 block.weak = False 1444 block.body.append(LabelInstr(name)) 1445 self._next(block) 1446 self.blocks.append(block) 1447 1448 def _kill(self, name: str): 1449 self.dead = True 1450 self.block.link_to(self._make(name)) 1451 1452 def _split(self, jmp: BasicBlock): 1453 self.jump = True 1454 link = BasicBlock.annonymous() 1455 self.labels[link.name] = link 1456 self.block.link_to(link) 1457 self.block.jump_to(jmp) 1458 self.blocks.append(link) 1459 1460 @staticmethod 1461 def _mk_align(v: int) -> int: 1462 if v & 15 == 0: 1463 return v 1464 else: 1465 print('* warning: SP is not aligned with 16 bytes.', file = sys.stderr) 1466 return (v + 15) & -16 1467 1468 @staticmethod 1469 def 
_is_spadj(ins: Instruction) -> bool: 1470 return len(ins.instr.operands) == 3 and \ 1471 isinstance(ins.instr.operands[1], mcasm.mc.Register) and \ 1472 isinstance(ins.instr.operands[2], int) and \ 1473 ins.instr.operands[1].name == 'RSP' 1474 1475 @staticmethod 1476 def _is_spmove(ins: Instruction, i: int) -> bool: 1477 return len(ins.operands) == 2 and \ 1478 isinstance(ins.operands[0], Register) and \ 1479 isinstance(ins.operands[1], Register) and \ 1480 ins.operands[i].reg == 'rsp' 1481 1482 @staticmethod 1483 def _is_rjump(ins: Optional[Instr]) -> bool: 1484 return isinstance(ins, X86Instr) and ins.instr.is_branch_label 1485 1486 def _find_label(self, name: str, adjs: Iterable[int], size: int = 0) -> int: 1487 for adj, block in zip(adjs, self.blocks): 1488 if block.name == name: 1489 return size 1490 else: 1491 # find block size from cache 1492 v = self.bsmap_.get(block.name) 1493 if v is not None: 1494 size += v + adj 1495 else: 1496 block_size = block.size_of(size) 1497 size += block_size + adj 1498 self.bsmap_[block.name] = block_size 1499 else: 1500 raise SyntaxError('unresolved reference to name: ' + name) 1501 1502 def _alloc_instr(self, instr: Instruction): 1503 if not instr.is_branch_label: 1504 self.block.body.append(X86Instr(instr)) 1505 else: 1506 self.block.body.append(BranchInstr(instr)) 1507 1508 # it seems to not be able to specify stack aligment inside the Go ASM so we 1509 # need to replace the aligned instructions with unaligned one if either of it's 1510 # operand is an RBP relative addressing memory operand 1511 1512 def _check_align(self, instr: Instruction) -> bool: 1513 # TODO: check 1514 return False 1515 1516 def _check_split(self, instr: Instruction): 1517 if instr.is_return: 1518 self.dead = True 1519 1520 elif instr.is_jmpq: # jmpq 1521 # backtrace jump table from current block (BFS) 1522 prevs = [self.block] 1523 visited = set() 1524 while len(prevs) > 0: 1525 curb = prevs.pop() 1526 if curb in visited: 1527 continue 1528 else: 
1529 visited.add(curb) 1530 1531 # backtrace instructions 1532 for ins in reversed(curb.body): 1533 if isinstance(ins, X86Instr) and ins.instr.jmptab: 1534 self._split(self._make(ins.instr.jmptab, jmptab = True)) 1535 return 1536 1537 if curb.prevs: 1538 prevs.extend(curb.prevs) 1539 1540 elif instr.is_branch_label: 1541 if instr.is_jmp: 1542 self._kill(instr.label_name) 1543 1544 elif instr.is_invoke: # call 1545 fname = instr.label_name 1546 self._split(self._make(fname, func = True)) 1547 1548 else: # jeq, ja, jae ... 1549 self._split(self._make(instr.label_name)) 1550 1551 def _trace_block(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int: 1552 if (pcsp is not None): 1553 if bb.name in self.funcs: 1554 # already traced 1555 pcsp = None 1556 else: 1557 # continue tracing, update the pcsp 1558 # NOTICE: must mark pcsp at block entry because go only calculate delta value 1559 pcsp.pc = self.get(bb.name) 1560 if bb.func or pcsp.pc < pcsp.entry: 1561 # new func 1562 pcsp = Pcsp(pcsp.pc) 1563 self.funcs[bb.name] = pcsp 1564 1565 if bb.maxsp == -1: 1566 ret = self._trace_nocache(bb, pcsp) 1567 return ret 1568 elif bb.maxsp >= 0: 1569 return bb.maxsp 1570 else: 1571 return 0 1572 1573 def _trace_nocache(self, bb: BasicBlock, pcsp: Optional[Pcsp]) -> int: 1574 bb.maxsp = -2 1575 1576 # ## FIXME: 1577 # if pcsp is None: 1578 # pcsp = Pcsp(0) 1579 1580 # make a fake object just for reducing redundant checking 1581 if pcsp: 1582 pc0, sp0 = pcsp.pc, pcsp.sp 1583 1584 maxsp, term = self._trace_instructions(bb, pcsp) 1585 1586 # this is a terminating block 1587 if term: 1588 return maxsp 1589 1590 # don't trace it's next block if it's an unconditional jump 1591 a, b = 0, 0 1592 if pcsp: 1593 pc, sp = pcsp.pc, pcsp.sp 1594 1595 if bb.jump: 1596 if bb.jump.jmptab: 1597 cases = self.get_jmptab(bb.jump.name) 1598 for case in cases: 1599 nsp = self._trace_block(case, pcsp) 1600 if pcsp: 1601 pcsp.pc, pcsp.sp = pc, sp 1602 if nsp > a: 1603 a = nsp 1604 else: 1605 a = 
self._trace_block(bb.jump, pcsp) 1606 if pcsp: 1607 pcsp.pc, pcsp.sp = pc, sp 1608 1609 if bb.next: 1610 b = self._trace_block(bb.next, pcsp) 1611 1612 if pcsp: 1613 pcsp.pc, pcsp.sp = pc0, sp0 1614 1615 # select the maximum stack depth 1616 bb.maxsp = maxsp + max(a, b) 1617 return bb.maxsp 1618 1619 def _trace_instructions(self, bb: BasicBlock, pcsp: Pcsp) -> Tuple[int, bool]: 1620 cursp = 0 1621 maxsp = 0 1622 close = False 1623 1624 # scan every instruction 1625 for ins in bb.body: 1626 diff = 0 1627 1628 if isinstance(ins, X86Instr): 1629 name = ins.instr.mnemonic 1630 operands = ins.instr.instr.operands 1631 1632 # check for instructions 1633 if name == 'ret': 1634 close = True 1635 elif isinstance(operands[0], mcasm.mc.Register) and operands[0].name == 'SP': 1636 # print(ins.instr.asm_code) 1637 if name == 'add': 1638 diff = -self._mk_align(operands[2]) 1639 elif name == 'sub': 1640 diff = self._mk_align(operands[2]) 1641 elif name == 'stp': 1642 diff = -self._mk_align(operands[4] * 8) 1643 elif name == 'ldp': 1644 diff = -self._mk_align(operands[4] * 8) 1645 elif name == 'str': 1646 diff = -self._mk_align(operands[3]) 1647 else: 1648 raise RuntimeError("An instruction adjsut sp but bot processed: %s" % ins.instr.asm_code) 1649 1650 cursp += diff 1651 1652 # update the max stack depth 1653 if cursp > maxsp: 1654 maxsp = cursp 1655 1656 # update pcsp 1657 if pcsp: 1658 pcsp.update(ins.size(pcsp.pc), diff) 1659 1660 # trace successful 1661 return maxsp, close 1662 1663 def get(self, key: str) -> Optional[int]: 1664 if key not in self.labels: 1665 raise SyntaxError('unresolved reference to name: %s' % key) 1666 else: 1667 return self._find_label(key, itertools.repeat(0, len(self.blocks))) 1668 1669 def has(self, key: str) -> bool: 1670 return key in self.labels 1671 1672 def emit(self, buf: bytes, comments: str = ''): 1673 if not self.dead: 1674 self.block.body.append(RawInstr(len(buf), Instruction.encode(buf, comments or buf.hex()), buf)) 1675 1676 def 
lazy(self, size: int, func: Callable[[], int], comments: str = ''): 1677 if not self.dead: 1678 self.block.body.append(IntInstr(size, func, comments)) 1679 1680 def label(self, name: str): 1681 if name not in self.labels or self.labels[name].weak: 1682 self._decl(name, self._make(name)) 1683 else: 1684 raise SyntaxError('duplicated label: ' + name) 1685 1686 def instr(self, instr: Instruction): 1687 if not self.dead: 1688 if self._check_align(instr): 1689 return 1690 self._alloc_instr(instr) 1691 self._check_split(instr) 1692 1693 # @functools.cache 1694 def stacksize(self, name: str) -> int: 1695 if name not in self.labels: 1696 raise SyntaxError('undefined function: ' + name) 1697 else: 1698 return self._trace_block(self.labels[name], None) 1699 1700 # @functools.cache 1701 def pcsp(self, name: str, entry: int) -> int: 1702 if name not in self.labels: 1703 raise SyntaxError('undefined function: ' + name) 1704 else: 1705 pcsp = Pcsp(entry) 1706 self.labels[name].func = True 1707 return self._trace_block(self.labels[name], pcsp) 1708 1709 def debug(self, pos: int, inss: List[Instruction]): 1710 def inject(bb: BasicBlock) -> bool: 1711 if (not bb.func) and (bb.name not in self.funcs): 1712 return True 1713 nonlocal pos 1714 if pos >= len(bb.body): 1715 return 1716 for ins in inss: 1717 bb.body.insert(pos, ins) 1718 pos += 1 1719 visited = {} 1720 for _, bb in self.labels.items(): 1721 CodeSection._dfs_jump_first(bb, visited, inject) 1722 def debug(self): 1723 for label, bb in self.labels.items(): 1724 print(label) 1725 for v in bb.body: 1726 if isinstance(v, (X86Instr, BranchInstr)): 1727 print(v.instr.asm_code) 1728 1729 STUB_NAME = '__native_entry__' 1730 STUB_SIZE = 67 1731 WITH_OFFS = os.getenv('ASM2ASM_DEBUG_OFFSET', '').lower() in ('1', 'yes', 'true') 1732 1733 class Assembler: 1734 out : List[str] 1735 subr : Dict[str, int] 1736 code : CodeSection 1737 vals : Dict[str, Union[str, int]] 1738 1739 def __init__(self): 1740 self.out = [] 1741 self.subr = {} 1742 
self.vals = {} 1743 self.code = CodeSection() 1744 1745 def _get(self, v: str) -> int: 1746 if v not in self.vals: 1747 return self.code.get(v) 1748 elif isinstance(self.vals[v], int): 1749 return self.vals[v] 1750 else: 1751 ret = self.vals[v] = self._eval(self.vals[v]) 1752 return ret 1753 1754 def _eval(self, v: str) -> int: 1755 return Expression(v).eval(self._get) 1756 1757 def _emit(self, v: bytes, cmd: str): 1758 align_size = len(v) % 4 1759 if align_size != 0: 1760 v += int(0).to_bytes(4 - align_size, 'little') 1761 1762 for i in range(0, len(v), 4): 1763 self.code.emit(v[i:i + 4], '%s %d, %s' % (cmd, len(v[i:i + 4]), repr(v[i:i + 16])[1:])) 1764 1765 def _limit(self, v: int, a: int, b: int) -> int: 1766 if not (a <= v <= b): 1767 raise SyntaxError('integer constant out of bound [%d, %d): %d' % (a, b, v)) 1768 else: 1769 return v 1770 1771 def _vfill(self, cmd: str, args: List[str]) -> Tuple[int, int]: 1772 if len(args) == 1: 1773 return self._limit(self._eval(args[0]), 1, 1 << 64), 0 1774 elif len(args) == 2: 1775 return self._limit(self._eval(args[0]), 1, 1 << 64), self._limit(self._eval(args[1]), 0, 255) 1776 else: 1777 raise SyntaxError(cmd + ' takes 1 ~ 2 arguments') 1778 1779 def _bytes(self, cmd: str, args: List[str], low: int, high: int, size: int): 1780 if len(args) != 1: 1781 raise SyntaxError(cmd + ' takes exact 1 argument') 1782 else: 1783 self.code.lazy(size, lambda: self._limit(self._eval(args[0]), low, high) & high, '%s %s' % (cmd, args[0])) 1784 1785 def _comment(self, msg: str): 1786 self.code.blocks[-1].body.append(CommentInstr(msg)) 1787 1788 def _cmd_nop(self, _: List[str]): 1789 pass 1790 1791 def _cmd_set(self, args: List[str]): 1792 if len(args) != 2: 1793 raise SyntaxError(".set takes exact 2 argument") 1794 elif not args[0].isidentifier(): 1795 raise SyntaxError(repr(args[0]) + " is not a valid identifier") 1796 else: 1797 key = args[0] 1798 val = args[1] 1799 self.vals[key] = val 1800 self._comment('.set ' + ', '.join(args)) 1801 # 
special case: clang-generated jump tables are always like '{block}_{table}' 1802 jt = val.find(CLANG_JUMPTABLE_LABLE) 1803 if jt > 0: 1804 tab = self.code.get_jmptab(val[jt:]) 1805 tab.append(self.code.get_block(val[:jt-1])) 1806 1807 def _cmd_byte(self, args: List[str]): 1808 self._bytes('.byte', args, -0x80, 0xff, 1) 1809 1810 def _cmd_word(self, args: List[str]): 1811 self._bytes('.word', args, -0x8000, 0xffff, 2) 1812 1813 def _cmd_long(self, args: List[str]): 1814 self._bytes('.long', args, -0x80000000, 0xffffffff, 4) 1815 1816 def _cmd_quad(self, args: List[str]): 1817 self._bytes('.quad', args, -0x8000000000000000, 0xffffffffffffffff, 8) 1818 1819 def _cmd_ascii(self, args: List[str]): 1820 if len(args) != 1: 1821 raise SyntaxError('.ascii takes exact 1 argument') 1822 else: 1823 self._emit(args[0].encode('latin-1'), '.ascii') 1824 1825 def _cmd_asciz(self, args: List[str]): 1826 if len(args) != 1: 1827 raise SyntaxError('.asciz takes exact 1 argument') 1828 else: 1829 self._emit(args[0].encode('latin-1') + b'\0', '.asciz') 1830 1831 def _cmd_space(self, args: List[str]): 1832 nb, fv = self._vfill('.space', args) 1833 self._emit(bytes([fv] * nb), '.space') 1834 1835 def _cmd_p2align(self, args: List[str]): 1836 if len(args) == 1: 1837 self.code.block.body.append(AlignmentInstr(self._eval(args[0]))) 1838 elif len(args) == 2: 1839 self.code.block.body.append(AlignmentInstr(self._eval(args[0]), self._eval(args[1]))) 1840 else: 1841 raise SyntaxError('.p2align takes 1 ~ 2 arguments') 1842 1843 @functools.cached_property 1844 def _commands(self) -> dict: 1845 return { 1846 '.set' : self._cmd_set, 1847 '.int' : self._cmd_long, 1848 '.long' : self._cmd_long, 1849 '.byte' : self._cmd_byte, 1850 '.quad' : self._cmd_quad, 1851 '.word' : self._cmd_word, 1852 '.hword' : self._cmd_word, 1853 '.short' : self._cmd_word, 1854 '.ascii' : self._cmd_ascii, 1855 '.asciz' : self._cmd_asciz, 1856 '.space' : self._cmd_space, 1857 '.globl' : self._cmd_nop, 1858 '.text' : 
self._cmd_nop, 1859 '.file' : self._cmd_nop, 1860 '.type' : self._cmd_nop, 1861 '.p2align' : self._cmd_p2align, 1862 '.align' : self._cmd_nop, 1863 '.size' : self._cmd_nop, 1864 '.section' : self._cmd_nop, 1865 '.loh' : self._cmd_nop, 1866 '.data_region' : self._cmd_nop, 1867 '.build_version' : self._cmd_nop, 1868 '.end_data_region' : self._cmd_nop, 1869 '.subsections_via_symbols' : self._cmd_nop, 1870 # linux-gnu 1871 '.xword' :self._cmd_nop, 1872 } 1873 1874 @staticmethod 1875 def _is_rip_relative(op: Operand) -> bool: 1876 return isinstance(op, Memory) and \ 1877 op.base is not None and \ 1878 op.base.reg == 'rip' and \ 1879 op.index is None and \ 1880 isinstance(op.disp, Reference) 1881 1882 @staticmethod 1883 def _remove_comments(line: str, *, st: str = 'normal') -> str: 1884 for i, ch in enumerate(line): 1885 if st == 'normal' and ch == '/' : st = 'slcomm' 1886 elif st == 'normal' and ch == '\"' : st = 'string' 1887 # elif st == 'normal' and ch in ('#', ';') : return line[:i] 1888 elif st == 'normal' and ch in (';') : return line[:i] 1889 elif st == 'slcomm' and ch == '/' : return line[:i - 1] 1890 elif st == 'slcomm' : st = 'normal' 1891 elif st == 'string' and ch == '\"' : st = 'normal' 1892 elif st == 'string' and ch == '\\' : st = 'escape' 1893 elif st == 'escape' : st = 'string' 1894 else: 1895 return line 1896 1897 @staticmethod 1898 def _replace_adrp_line(line: str) -> str: 1899 if 'adrp' in line: 1900 line = line.replace('adrp', 'adr').replace('@PAGE', '') 1901 return line 1902 1903 @staticmethod 1904 def _replace_adrp(src: List[str]) -> List[str]: 1905 back_label_count = 0 1906 adrp_label_map = {} 1907 new_src = [] 1908 for line in src: 1909 line = Assembler._remove_comments(line) 1910 line = line.strip() 1911 1912 if not line: 1913 continue 1914 # is instructions 1915 if line[-1] != ':' and line[0] != '.': 1916 instr = Instruction(line, back_label_count) 1917 if instr.ADRP_label: 1918 back_label_count += 1 1919 new_src.append(instr.asm_code) 1920 
new_src.append(instr.back_label + ':') 1921 if instr.text_label in adrp_label_map: 1922 adrp_label_map[instr.text_label] += [(instr.ADRP_label, instr.ADR_instr, instr.back_label)] 1923 else: 1924 adrp_label_map[instr.text_label] = [(instr.ADRP_label, instr.ADR_instr, instr.back_label)] 1925 else: 1926 new_src.append(line) 1927 else: 1928 new_src.append(line) 1929 1930 nn_src = [] 1931 1932 for line in new_src: 1933 if line[-1] == ':': # is label 1934 if line[:-1] in adrp_label_map: 1935 for item in adrp_label_map[line[:-1]]: 1936 nn_src.append(item[0] + ':') # label that adrp will jump to 1937 nn_src.append(item[1]) # adr to get really symbol address 1938 nn_src.append('b ' + item[2]) # jump back to adrp next instruction 1939 nn_src.append(line) 1940 1941 return nn_src 1942 1943 def _parse(self, src: List[str]): 1944 # src = self._replace_adrp(o_src) 1945 1946 for line in src: 1947 line = Assembler._remove_comments(line) 1948 line = line.strip() 1949 1950 # skip empty lines 1951 if not line: 1952 continue 1953 1954 # labels, resolve the offset 1955 if line[-1] == ':': 1956 self.code.label(line[:-1]) 1957 continue 1958 1959 # instructions 1960 if line[0] != '.': 1961 line = self._replace_adrp_line(line) 1962 self.code.instr(Instruction(line, 0)) 1963 continue 1964 1965 # parse the command 1966 cmd = Command.parse(line) 1967 func = self._commands.get(cmd.cmd) 1968 1969 # handle the command 1970 if func is not None: 1971 func(cmd.args) 1972 else: 1973 raise SyntaxError('invalid assembly command: ' + cmd.cmd) 1974 1975 def _reloc(self, rip: int = 0): 1976 for block in self.code.blocks: 1977 for instr in block.body: 1978 rip += self._reloc_one(instr, rip) 1979 1980 def _reloc_one(self, instr: Instr, rip: int) -> int: 1981 if not isinstance(instr, (X86Instr, BranchInstr)): 1982 return instr.size(rip) 1983 elif instr.instr.need_reloc: 1984 return self._reloc_branch(instr.instr, rip) 1985 else: 1986 return instr.resize(self._reloc_normal(instr.instr, rip)) 1987 1988 def 
_reloc_branch(self, instr: Instruction, rip: int) -> int: 1989 label = instr.label_name 1990 if label is None: 1991 raise RuntimeError('cannnot found label name: %s' % instr.asm_code) 1992 if instr.mnemonic == 'adr' and label == 'Ltmp0': 1993 instr.set_label_offset(-4) 1994 else: 1995 instr.set_label_offset(self.code.get(label)- rip - instr.size) 1996 1997 return instr.size 1998 1999 def _reloc_normal(self, instr: Instruction, rip: int) -> int: 2000 if instr.need_reloc: 2001 raise SyntaxError('unresolved instruction when relocation: ' + instr.asm_code) 2002 return instr.size 2003 2004 def _LE_4bytes_IntIntr_2_RawIntr(self): 2005 for block in self.code.blocks: 2006 block.if_all_IntInstr_then_2_RawInstr() 2007 2008 def _declare(self, protos: PrototypeMap): 2009 if OUTPUT_RAW: 2010 self._declare_body_raw() 2011 else: 2012 name = next(iter(protos)) 2013 self._declare_body(name[1:]) 2014 self._declare_functions(protos) 2015 2016 def _declare_body(self, name: str): 2017 size = self.code.stacksize(name) 2018 gosize = 0 if size < 16 else size-16 2019 self.out.append('TEXT ·_%s_entry__(SB), NOSPLIT, $%d' % (name, gosize)) 2020 self.out.append('\tNO_LOCAL_POINTERS') 2021 # get current PC 2022 self.out.append('\tWORD $0x100000a0 // adr x0, .+20') 2023 # self.out.append('\t'+Instruction('add sp, sp, #%d' % size).encoded) 2024 self.out.append('\tMOVD R0, ret(FP)') 2025 self.out.append('\tRET') 2026 self._LE_4bytes_IntIntr_2_RawIntr() 2027 self._reloc() 2028 2029 # instruction buffer 2030 pc = 0 2031 ins = self.code.instrs 2032 2033 for v in ins: 2034 self.out.append(('// +%d\n' % pc if WITH_OFFS else '') + v.formatted(pc)) 2035 pc += v.size(pc) 2036 2037 def _declare_body_raw(self): 2038 self._reloc() 2039 2040 # instruction buffer 2041 pc = 0 2042 ins = self.code.instrs 2043 2044 # dump every instruction 2045 for v in ins: 2046 self.out.append(v.raw_formatted(pc)) 2047 pc += v.size(pc) 2048 2049 def _declare_function(self, name: str, proto: Prototype): 2050 offs = 0 2051 subr 
= name[1:] 2052 addr = self.code.get(subr) 2053 self.subr[subr] = addr 2054 size = self.code.stacksize(subr) 2055 2056 m_size = size + 64 2057 # rsp_sub_size = size + 16 2058 2059 if OUTPUT_RAW: 2060 return 2061 2062 # function header and stack checking 2063 self.out.append('') 2064 # frame size is 16 to store x29 and x30 2065 # self.out.append('TEXT ·%s(SB), NOSPLIT | NOFRAME, $0-%d' % (name, proto.argspace)) 2066 self.out.append('TEXT ·%s(SB), NOSPLIT, $%d-%d' % (name, 0, proto.argspace)) 2067 self.out.append('\tNO_LOCAL_POINTERS') 2068 2069 # add stack check if needed 2070 if m_size != 0: 2071 self.out.append('') 2072 self.out.append('_entry:') 2073 self.out.append('\tMOVD 16(g), R16') 2074 if size > 0: 2075 if size < (0x1 << 12) - 1: 2076 self.out.append('\tSUB $%d, RSP, R17' % (m_size)) 2077 elif size < (0x1 << 16) - 1: 2078 self.out.append('\tMOVD $%d, R17' % (m_size)) 2079 self.out.append('\tSUB R17, RSP, R17') 2080 else: 2081 raise RuntimeError('too large stack size: %d' % (m_size)) 2082 self.out.append('\tCMP R16, R17') 2083 else: 2084 self.out.append('\tCMP R16, RSP') 2085 self.out.append('\tBLS _stack_grow') 2086 2087 # function name 2088 self.out.append('') 2089 self.out.append('%s:' % subr) 2090 2091 # self.out.append('\tMOVD.W R30, -16(RSP)') 2092 # self.out.append('\tMOVD R29, -8(RSP)') 2093 # self.out.append('\tSUB $8, RSP, R29') 2094 2095 # intialize all the arguments 2096 for arg in proto.args: 2097 offs += arg.size 2098 op, reg = REG_MAP[arg.creg.reg] 2099 self.out.append('\t%s %s+%d(FP), %s' % (op, arg.name, offs - arg.size, reg)) 2100 2101 2102 # Go ASM completely ignores the offset of the JMP instruction, 2103 # so we need to use indirect jumps instead for tail-call elimination 2104 2105 # LEA and JUMP 2106 self.out.append('\tMOVD ·_subr_%s(SB), R11' % (subr)) 2107 self.out.append('\tWORD $0x1000005e // adr x30, .+8') 2108 self.out.append('\tJMP (R11)') 2109 # self.out.append('\tCALL ·_%s_entry__(SB) // %s' % (subr, subr)) 2110 2111 # normal 
functions, call the real function, and return the result 2112 if proto.retv is not None: 2113 self.out.append('\t%s, %s+%d(FP)' % (' '.join(REG_MAP[proto.retv.creg.reg]), proto.retv.name, offs)) 2114 # Restore LR and Frame Pointer 2115 # self.out.append('\tLDP -8(RSP), (R29, R30)') 2116 # self.out.append('\tADD $16, RSP') 2117 2118 self.out.append('\tRET') 2119 2120 # add stack growing if needed 2121 if m_size != 0: 2122 self.out.append('') 2123 self.out.append('_stack_grow:') 2124 self.out.append('\tMOVD R30, R3') 2125 self.out.append('\tCALL runtime·morestack_noctxt<>(SB)') 2126 self.out.append('\tJMP _entry') 2127 2128 def _declare_functions(self, protos: PrototypeMap): 2129 for name, proto in sorted(protos.items()): 2130 if name[0] == '_': 2131 self._declare_function(name, proto) 2132 else: 2133 raise SyntaxError('function prototype must have a "_" prefix: ' + repr(name)) 2134 2135 def parse(self, src: List[str], proto: PrototypeMap): 2136 # self.code.instr(Instruction('adr x0, .')) 2137 # self.code.instr(Instruction('add sp, sp, #%d'%self.code.stacksize(name))) 2138 # self.code.instr(Instruction('ret')) 2139 # cmd = Command.parse(".p2align 4") 2140 # func = self._commands.get(cmd.cmd) 2141 # func(cmd.args) 2142 2143 self._parse(src) 2144 self._declare(proto) 2145 2146 GOOS = { 2147 'aix', 2148 'android', 2149 'darwin', 2150 'dragonfly', 2151 'freebsd', 2152 'hurd', 2153 'illumos', 2154 'js', 2155 'linux', 2156 'nacl', 2157 'netbsd', 2158 'openbsd', 2159 'plan9', 2160 'solaris', 2161 'windows', 2162 'zos', 2163 } 2164 2165 GOARCH = { 2166 '386', 2167 'amd64', 2168 'amd64p32', 2169 'arm', 2170 'armbe', 2171 'arm64', 2172 'arm64be', 2173 'ppc64', 2174 'ppc64le', 2175 'mips', 2176 'mipsle', 2177 'mips64', 2178 'mips64le', 2179 'mips64p32', 2180 'mips64p32le', 2181 'ppc', 2182 'riscv', 2183 'riscv64', 2184 's390', 2185 's390x', 2186 'sparc', 2187 'sparc64', 2188 'wasm', 2189 } 2190 2191 def make_subr_filename(name: str) -> str: 2192 name = os.path.basename(name) 
2193 base = os.path.splitext(name)[0].rsplit('_', 2) 2194 2195 # construct the new name 2196 if base[-1] in GOOS: 2197 return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1]) 2198 elif base[-1] not in GOARCH: 2199 return '%s_subr.go' % '_'.join(base) 2200 elif len(base) > 2 and base[-2] in GOOS: 2201 return '%s_subr_%s_%s.go' % ('_'.join(base[:-2]), base[-2], base[-1]) 2202 else: 2203 return '%s_subr_%s.go' % ('_'.join(base[:-1]), base[-1]) 2204 2205 def parse_args(): 2206 parser = argparse.ArgumentParser(description='Convert llvm asm to golang asm.') 2207 parser.add_argument('proto_file', type=str, help = 'The go file that declares go functions') 2208 parser.add_argument('asm_file', type=str, nargs='+', help = 'The llvm assembly file') 2209 parser.add_argument('-r', default=False, action='store_true', help = 'Ture: output as raw; default is False') 2210 return parser.parse_args() 2211 2212 def main(): 2213 src = [] 2214 args = parse_args() 2215 2216 # check if optional flag is enabled 2217 global OUTPUT_RAW 2218 OUTPUT_RAW = False 2219 if args.r: 2220 OUTPUT_RAW = True 2221 2222 proto_name = os.path.splitext(args.proto_file)[0] 2223 2224 # parse the prototype 2225 with open(proto_name + '.go', 'r', newline = None) as fp: 2226 pkg, proto = PrototypeMap.parse(fp.read()) 2227 2228 # read all the sources, and combine them together 2229 for fn in args.asm_file: 2230 with open(fn, 'r', newline = None) as fp: 2231 src.extend(fp.read().splitlines()) 2232 2233 asm = Assembler() 2234 2235 # convert the original sources 2236 if OUTPUT_RAW: 2237 asm.out.append('// +build arm64') 2238 asm.out.append('// Code generated by asm2asm, DO NOT EDIT.') 2239 asm.out.append('') 2240 asm.out.append('package %s' % pkg) 2241 asm.out.append('') 2242 ## native text 2243 asm.out.append('var Text%s = []byte{' % STUB_NAME) 2244 else: 2245 asm.out.append('// +build !noasm !appengine') 2246 asm.out.append('// Code generated by asm2asm, DO NOT EDIT.') 2247 asm.out.append('') 2248 
asm.out.append('#include "go_asm.h"') 2249 asm.out.append('#include "funcdata.h"') 2250 asm.out.append('#include "textflag.h"') 2251 asm.out.append('') 2252 2253 asm.parse(src, proto) 2254 2255 if OUTPUT_RAW: 2256 asrc = proto_name[:proto_name.rfind('_')] + '_text_arm.go' 2257 else: 2258 asrc = proto_name + '.s' 2259 2260 # save the converted result 2261 with open(asrc, 'w') as fp: 2262 for line in asm.out: 2263 print(line, file = fp) 2264 if OUTPUT_RAW: 2265 print('}', file = fp) 2266 2267 # calculate the subroutine stub file name 2268 subr = make_subr_filename(args.proto_file) 2269 subr = os.path.join(os.path.dirname(args.proto_file), subr) 2270 2271 # save the compiled code stub 2272 with open(subr, 'w') as fp: 2273 print('// +build !noasm !appengine', file = fp) 2274 print('// Code generated by asm2asm, DO NOT EDIT.', file = fp) 2275 print(file = fp) 2276 print('package %s' % pkg, file = fp) 2277 2278 # also save the actual function addresses if any 2279 if not asm.subr: 2280 return 2281 2282 if OUTPUT_RAW: 2283 print(file = fp) 2284 print('import (\n\t`github.com/bytedance/sonic/loader`\n)', file = fp) 2285 2286 # dump every entry for all functions 2287 print(file = fp) 2288 print('const (', file = fp) 2289 for name in asm.code.funcs.keys(): 2290 addr = asm.code.get(name) 2291 if addr is not None: 2292 print(f' _entry_{name} = %d' % addr, file = fp) 2293 print(')', file = fp) 2294 2295 # dump max stack depth for all functions 2296 print(file = fp) 2297 print('const (', file = fp) 2298 for name in asm.code.funcs.keys(): 2299 print(' _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp) 2300 print(')', file = fp) 2301 2302 # dump every text size for all functions 2303 print(file = fp) 2304 print('const (', file = fp) 2305 for name, pcsp in asm.code.funcs.items(): 2306 if pcsp is not None: 2307 # print(f'before {name} optimize {pcsp}') 2308 pcsp.optimize() 2309 # print(f'after {name} optimize {pcsp}') 2310 print(f' _size_{name} = %d' % (pcsp.maxpc - 
pcsp.entry), file = fp) 2311 print(')', file = fp) 2312 2313 # dump every pcsp for all functions 2314 print(file = fp) 2315 print('var (', file = fp) 2316 for name, pcsp in asm.code.funcs.items(): 2317 if pcsp is not None: 2318 print(f' _pcsp_{name} = %s' % pcsp, file = fp) 2319 print(')', file = fp) 2320 2321 # insert native entry info 2322 print(file = fp) 2323 print('var Funcs = []loader.CFunc{', file = fp) 2324 print(' {"%s", 0, %d, 0, nil},' % (STUB_NAME, STUB_SIZE), file = fp) 2325 # dump every native function info for all functions 2326 for name in asm.code.funcs.keys(): 2327 print(' {"%s", _entry_%s, _size_%s, _stack_%s, _pcsp_%s},' % (name, name, name, name, name), file = fp) 2328 print('}', file = fp) 2329 2330 else: 2331 # native entry for entry function 2332 print(file = fp) 2333 print('//go:nosplit', file = fp) 2334 print('//go:noescape', file = fp) 2335 print('//goland:noinspection ALL', file = fp) 2336 for name, entry in asm.subr.items(): 2337 print('func _%s_entry__() uintptr' % name, file = fp) 2338 2339 # dump exported function entry for exported functions 2340 print(file = fp) 2341 print('var (', file = fp) 2342 mlen = max(len(s) for s in asm.subr) 2343 for name, entry in asm.subr.items(): 2344 print(' _subr_%s uintptr = _%s_entry__() + %d' % (name.ljust(mlen, ' '), name, entry), file = fp) 2345 # print(' _subr_%s uintptr = %d' % (name.ljust(mlen, ' '), entry), file = fp) 2346 print(')', file = fp) 2347 2348 # dump max stack depth for exported functions 2349 print(file = fp) 2350 print('const (', file = fp) 2351 for name in asm.subr.keys(): 2352 print(' _stack_%s = %d' % (name, asm.code.stacksize(name)), file = fp) 2353 print(')', file = fp) 2354 2355 # assign subroutine offsets to '_' to mute the "unused" warnings 2356 print(file = fp) 2357 print('var (', file = fp) 2358 for name in asm.subr: 2359 print(' _ = _subr_%s' % name, file = fp) 2360 print(')', file = fp) 2361 2362 # dump every constant 2363 print(file = fp) 2364 print('const (', file = 
fp) 2365 for name in asm.subr: 2366 print(' _ = _stack_%s' % name, file = fp) 2367 else: 2368 print(')', file = fp) 2369 2370 if __name__ == '__main__': 2371 main()