github.com/decomp/exp@v0.0.0-20210624183419-6d058f5e1da6/cmd/bin2c/arg.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"log"
     8  	"strconv"
     9  
    10  	"github.com/mewkiz/pkg/errutil"
    11  	"golang.org/x/arch/x86/x86asm"
    12  )
    13  
    14  // getArg converts arg into a corresponding Go expression.
    15  func getArg(arg x86asm.Arg) ast.Expr {
    16  	switch arg := arg.(type) {
    17  	case x86asm.Reg:
    18  		return getReg(arg)
    19  	case x86asm.Mem:
    20  		return getMem(arg)
    21  	case x86asm.Imm:
    22  		return createExpr(int64(arg))
    23  	case x86asm.Rel:
    24  		// TODO: Implement support for relative addresses.
    25  	}
    26  	fmt.Printf("%#v\n", arg)
    27  	log.Fatal(errutil.Newf("support for type %T not yet implemented", arg))
    28  	panic("unreachable")
    29  }
    30  
    31  // regs maps register names to their corresponding Go identifiers.
    32  var regs = map[string]*ast.Ident{
    33  	// 8-bit
    34  	"AL":   ast.NewIdent("al"),
    35  	"CL":   ast.NewIdent("cl"),
    36  	"DL":   ast.NewIdent("dl"),
    37  	"BL":   ast.NewIdent("bl"),
    38  	"AH":   ast.NewIdent("ah"),
    39  	"CH":   ast.NewIdent("ch"),
    40  	"DH":   ast.NewIdent("dh"),
    41  	"BH":   ast.NewIdent("bh"),
    42  	"SPB":  ast.NewIdent("spb"),
    43  	"BPB":  ast.NewIdent("bpb"),
    44  	"SIB":  ast.NewIdent("sib"),
    45  	"DIB":  ast.NewIdent("dib"),
    46  	"R8B":  ast.NewIdent("r8b"),
    47  	"R9B":  ast.NewIdent("r9b"),
    48  	"R10B": ast.NewIdent("r10b"),
    49  	"R11B": ast.NewIdent("r11b"),
    50  	"R12B": ast.NewIdent("r12b"),
    51  	"R13B": ast.NewIdent("r13b"),
    52  	"R14B": ast.NewIdent("r14b"),
    53  	"R15B": ast.NewIdent("r15b"),
    54  
    55  	// 16-bit
    56  	"AX":   ast.NewIdent("ax"),
    57  	"CX":   ast.NewIdent("cx"),
    58  	"DX":   ast.NewIdent("dx"),
    59  	"BX":   ast.NewIdent("bx"),
    60  	"SP":   ast.NewIdent("sp"),
    61  	"BP":   ast.NewIdent("bp"),
    62  	"SI":   ast.NewIdent("si"),
    63  	"DI":   ast.NewIdent("di"),
    64  	"R8W":  ast.NewIdent("r8w"),
    65  	"R9W":  ast.NewIdent("r9w"),
    66  	"R10W": ast.NewIdent("r10w"),
    67  	"R11W": ast.NewIdent("r11w"),
    68  	"R12W": ast.NewIdent("r12w"),
    69  	"R13W": ast.NewIdent("r13w"),
    70  	"R14W": ast.NewIdent("r14w"),
    71  	"R15W": ast.NewIdent("r15w"),
    72  
    73  	// 32-bit
    74  	"EAX":  ast.NewIdent("eax"),
    75  	"ECX":  ast.NewIdent("ecx"),
    76  	"EDX":  ast.NewIdent("edx"),
    77  	"EBX":  ast.NewIdent("ebx"),
    78  	"ESP":  ast.NewIdent("esp"),
    79  	"EBP":  ast.NewIdent("ebp"),
    80  	"ESI":  ast.NewIdent("esi"),
    81  	"EDI":  ast.NewIdent("edi"),
    82  	"R8L":  ast.NewIdent("r8l"),
    83  	"R9L":  ast.NewIdent("r9l"),
    84  	"R10L": ast.NewIdent("r10l"),
    85  	"R11L": ast.NewIdent("r11l"),
    86  	"R12L": ast.NewIdent("r12l"),
    87  	"R13L": ast.NewIdent("r13l"),
    88  	"R14L": ast.NewIdent("r14l"),
    89  	"R15L": ast.NewIdent("r15l"),
    90  
    91  	// 64-bit
    92  	"RAX": ast.NewIdent("rax"),
    93  	"RCX": ast.NewIdent("rcx"),
    94  	"RDX": ast.NewIdent("rdx"),
    95  	"RBX": ast.NewIdent("rbx"),
    96  	"RSP": ast.NewIdent("rsp"),
    97  	"RBP": ast.NewIdent("rbp"),
    98  	"RSI": ast.NewIdent("rsi"),
    99  	"RDI": ast.NewIdent("rdi"),
   100  	"R8":  ast.NewIdent("r8"),
   101  	"R9":  ast.NewIdent("r9"),
   102  	"R10": ast.NewIdent("r10"),
   103  	"R11": ast.NewIdent("r11"),
   104  	"R12": ast.NewIdent("r12"),
   105  	"R13": ast.NewIdent("r13"),
   106  	"R14": ast.NewIdent("r14"),
   107  	"R15": ast.NewIdent("r15"),
   108  
   109  	// Instruction pointer.
   110  	"IP":  ast.NewIdent("ip"),  // 16-bit
   111  	"EIP": ast.NewIdent("eip"), // 32-bit
   112  	"RIP": ast.NewIdent("rip"), // 64-bit
   113  
   114  	// 387 floating point registers.
   115  	"F0": ast.NewIdent("f0"),
   116  	"F1": ast.NewIdent("f1"),
   117  	"F2": ast.NewIdent("f2"),
   118  	"F3": ast.NewIdent("f3"),
   119  	"F4": ast.NewIdent("f4"),
   120  	"F5": ast.NewIdent("f5"),
   121  	"F6": ast.NewIdent("f6"),
   122  	"F7": ast.NewIdent("f7"),
   123  
   124  	// MMX registers.
   125  	"M0": ast.NewIdent("m0"),
   126  	"M1": ast.NewIdent("m1"),
   127  	"M2": ast.NewIdent("m2"),
   128  	"M3": ast.NewIdent("m3"),
   129  	"M4": ast.NewIdent("m4"),
   130  	"M5": ast.NewIdent("m5"),
   131  	"M6": ast.NewIdent("m6"),
   132  	"M7": ast.NewIdent("m7"),
   133  
   134  	// XMM registers.
   135  	"X0":  ast.NewIdent("x0"),
   136  	"X1":  ast.NewIdent("x1"),
   137  	"X2":  ast.NewIdent("x2"),
   138  	"X3":  ast.NewIdent("x3"),
   139  	"X4":  ast.NewIdent("x4"),
   140  	"X5":  ast.NewIdent("x5"),
   141  	"X6":  ast.NewIdent("x6"),
   142  	"X7":  ast.NewIdent("x7"),
   143  	"X8":  ast.NewIdent("x8"),
   144  	"X9":  ast.NewIdent("x9"),
   145  	"X10": ast.NewIdent("x10"),
   146  	"X11": ast.NewIdent("x11"),
   147  	"X12": ast.NewIdent("x12"),
   148  	"X13": ast.NewIdent("x13"),
   149  	"X14": ast.NewIdent("x14"),
   150  	"X15": ast.NewIdent("x15"),
   151  
   152  	// Segment registers.
   153  	"ES": ast.NewIdent("es"),
   154  	"CS": ast.NewIdent("cs"),
   155  	"SS": ast.NewIdent("ss"),
   156  	"DS": ast.NewIdent("ds"),
   157  	"FS": ast.NewIdent("fs"),
   158  	"GS": ast.NewIdent("gs"),
   159  
   160  	// System registers.
   161  	"GDTR": ast.NewIdent("gdtr"),
   162  	"IDTR": ast.NewIdent("idtr"),
   163  	"LDTR": ast.NewIdent("ldtr"),
   164  	"MSW":  ast.NewIdent("msw"),
   165  	"TASK": ast.NewIdent("task"),
   166  
   167  	// Control registers.
   168  	"CR0":  ast.NewIdent("cr0"),
   169  	"CR1":  ast.NewIdent("cr1"),
   170  	"CR2":  ast.NewIdent("cr2"),
   171  	"CR3":  ast.NewIdent("cr3"),
   172  	"CR4":  ast.NewIdent("cr4"),
   173  	"CR5":  ast.NewIdent("cr5"),
   174  	"CR6":  ast.NewIdent("cr6"),
   175  	"CR7":  ast.NewIdent("cr7"),
   176  	"CR8":  ast.NewIdent("cr8"),
   177  	"CR9":  ast.NewIdent("cr9"),
   178  	"CR10": ast.NewIdent("cr10"),
   179  	"CR11": ast.NewIdent("cr11"),
   180  	"CR12": ast.NewIdent("cr12"),
   181  	"CR13": ast.NewIdent("cr13"),
   182  	"CR14": ast.NewIdent("cr14"),
   183  	"CR15": ast.NewIdent("cr15"),
   184  
   185  	// Debug registers.
   186  	"DR0":  ast.NewIdent("dr0"),
   187  	"DR1":  ast.NewIdent("dr1"),
   188  	"DR2":  ast.NewIdent("dr2"),
   189  	"DR3":  ast.NewIdent("dr3"),
   190  	"DR4":  ast.NewIdent("dr4"),
   191  	"DR5":  ast.NewIdent("dr5"),
   192  	"DR6":  ast.NewIdent("dr6"),
   193  	"DR7":  ast.NewIdent("dr7"),
   194  	"DR8":  ast.NewIdent("dr8"),
   195  	"DR9":  ast.NewIdent("dr9"),
   196  	"DR10": ast.NewIdent("dr10"),
   197  	"DR11": ast.NewIdent("dr11"),
   198  	"DR12": ast.NewIdent("dr12"),
   199  	"DR13": ast.NewIdent("dr13"),
   200  	"DR14": ast.NewIdent("dr14"),
   201  	"DR15": ast.NewIdent("dr15"),
   202  
   203  	// Task registers.
   204  	"TR0": ast.NewIdent("tr0"),
   205  	"TR1": ast.NewIdent("tr1"),
   206  	"TR2": ast.NewIdent("tr2"),
   207  	"TR3": ast.NewIdent("tr3"),
   208  	"TR4": ast.NewIdent("tr4"),
   209  	"TR5": ast.NewIdent("tr5"),
   210  	"TR6": ast.NewIdent("tr6"),
   211  	"TR7": ast.NewIdent("tr7"),
   212  }
   213  
   214  // getReg converts reg into a corresponding Go expression.
   215  func getReg(reg x86asm.Reg) ast.Expr {
   216  	return getRegFromString(reg.String())
   217  }
   218  
   219  // getRegFromString converts reg into a corresponding Go expression.
   220  func getRegFromString(reg string) ast.Expr {
   221  	if expr, ok := regs[reg]; ok {
   222  		return expr
   223  	}
   224  	log.Fatal(errutil.Newf("unable to lookup identifer for register %q", reg))
   225  	panic("unreachable")
   226  }
   227  
   228  // getMem converts mem into a corresponding Go expression.
   229  func getMem(mem x86asm.Mem) ast.Expr {
   230  	// TODO: Replace 1*x with x in Scale*Index.
   231  
   232  	// The general memory reference form is:
   233  	//    Segment:[Base+Scale*Index+Disp]
   234  
   235  	// ... + Disp
   236  	expr := &ast.BinaryExpr{}
   237  	if mem.Disp != 0 {
   238  		disp := createExpr(mem.Disp)
   239  		expr.Op = token.ADD
   240  		expr.Y = disp
   241  	}
   242  
   243  	// ... + (Scale*Index) + ...
   244  	if mem.Scale != 0 && mem.Index != 0 {
   245  		scale := createExpr(mem.Scale)
   246  		index := getReg(mem.Index)
   247  		product := &ast.BinaryExpr{
   248  			X:  scale,
   249  			Op: token.MUL,
   250  			Y:  index,
   251  		}
   252  		switch {
   253  		case expr.Y == nil:
   254  			// ... + (Scale*Index)
   255  			expr.Op = token.ADD
   256  			expr.Y = product
   257  		default:
   258  			// ... + (Scale*Index) + Disp
   259  			expr.X = product
   260  			expr.Op = token.ADD
   261  		}
   262  	}
   263  
   264  	// ... + Base + ...
   265  	if mem.Base != 0 {
   266  		base := getReg(mem.Base)
   267  		switch {
   268  		case expr.X == nil:
   269  			// Base + (Scale*Index)
   270  			// or
   271  			// Base + Disp
   272  			expr.X = base
   273  			expr.Op = token.ADD
   274  		case expr.Y == nil:
   275  			// ... + Base
   276  			expr.Op = token.ADD
   277  			expr.Y = base
   278  		default:
   279  			sum := &ast.BinaryExpr{
   280  				X:  expr.X,
   281  				Op: token.ADD,
   282  				Y:  expr.Y,
   283  			}
   284  			expr.X = base
   285  			expr.Op = token.ADD
   286  			expr.Y = sum
   287  		}
   288  	}
   289  
   290  	// TODO: Figure out how the calculation is affected by segment in:
   291  	//    Segment:[Base+Scale*Index+Disp]
   292  	if mem.Segment != 0 {
   293  		segment := getReg(mem.Segment)
   294  		_ = segment
   295  		fmt.Printf("%#v\n", mem)
   296  		log.Fatal(errutil.Newf("support for Mem.Segment not yet implemented"))
   297  	}
   298  
   299  	switch {
   300  	case expr.X == nil && expr.Y == nil:
   301  		fmt.Printf("%#v\n", mem)
   302  		log.Fatal(errutil.New("support for memory reference to address zero not yet implemented"))
   303  		panic("unreachable")
   304  	case expr.X == nil && expr.Y != nil:
   305  		return createPtrDeref(expr.Y)
   306  	case expr.X != nil && expr.Y == nil:
   307  		return createPtrDeref(expr.X)
   308  	default:
   309  		return createPtrDeref(expr)
   310  	}
   311  }
   312  
   313  // createPtrDeref returns a pointer dereference expression of addr.
   314  func createPtrDeref(addr ast.Expr) ast.Expr {
   315  	return &ast.StarExpr{X: &ast.ParenExpr{X: addr}}
   316  }
   317  
   318  // createExpr converts x into a corresponding Go expression.
   319  func createExpr(x interface{}) ast.Expr {
   320  	switch x := x.(type) {
   321  	case int:
   322  		s := strconv.FormatInt(int64(x), 10)
   323  		return &ast.BasicLit{Kind: token.INT, Value: s}
   324  	case int64:
   325  		s := strconv.FormatInt(x, 10)
   326  		return &ast.BasicLit{Kind: token.INT, Value: s}
   327  	case uint8:
   328  		s := strconv.FormatUint(uint64(x), 10)
   329  		return &ast.BasicLit{Kind: token.INT, Value: s}
   330  	}
   331  	log.Fatal(errutil.Newf("support for type %T not yet implemented", x))
   332  	panic("unreachable")
   333  }
   334  
   335  // fromSubReg returns an equivalent expression to x, where x may be a sub-
   336  // register.
   337  func fromSubReg(sub ast.Expr) ast.Expr {
   338  	// TODO: Handle sub-registers (al, ah, ax)
   339  
   340  	// TODO: Fix operator precedence for C.
   341  	//    warning: & has lower precedence than <; < will be evaluated first
   342  	//    cf = *((int8_t *)ebp + -1) < ebx&255;
   343  
   344  	// Handle sub-registers (e.g. al, ah, ax).
   345  	if isSubLow8(sub) {
   346  		// Before:
   347  		//    al
   348  		// After:
   349  		//    eax&0x000000FF
   350  		return &ast.BinaryExpr{
   351  			X:  extendSubReg(sub),
   352  			Op: token.AND,
   353  			Y:  createExpr(0x000000FF),
   354  		}
   355  	}
   356  	if isSubHigh8(sub) {
   357  		// Before:
   358  		//    ah
   359  		// After:
   360  		//    (eax&0x0000FF00)>>8
   361  		paren := &ast.ParenExpr{
   362  			X: &ast.BinaryExpr{
   363  				X:  extendSubReg(sub),
   364  				Op: token.AND,
   365  				Y:  createExpr(0x0000FF00),
   366  			},
   367  		}
   368  		return &ast.BinaryExpr{
   369  			X:  paren,
   370  			Op: token.SHR,
   371  			Y:  createExpr(8),
   372  		}
   373  	}
   374  	if isSub16(sub) {
   375  		panic("not yet implemented.")
   376  	}
   377  	return sub
   378  }
   379  
   380  // subLow8 maps lower 8-bit sub-registers to their parent register.
   381  var subLow8 = map[string]string{
   382  	"al":   "EAX",
   383  	"cl":   "ECX",
   384  	"dl":   "EDX",
   385  	"bl":   "EBX",
   386  	"spb":  "ESP",
   387  	"bpb":  "EBP",
   388  	"sib":  "ESI",
   389  	"dib":  "EDI",
   390  	"r8b":  "R8L",
   391  	"r9b":  "R9L",
   392  	"r10b": "R10L",
   393  	"r11b": "R11L",
   394  	"r12b": "R12L",
   395  	"r13b": "R13L",
   396  	"r14b": "R14L",
   397  	"r15b": "R15L",
   398  }
   399  
   400  // isSubLow8 reports whether x is a lower 8-bit sub-register.
   401  func isSubLow8(x ast.Expr) bool {
   402  	if sub, ok := x.(*ast.Ident); ok {
   403  		_, ok = subLow8[sub.Name]
   404  		return ok
   405  	}
   406  	return false
   407  }
   408  
   409  // subHigh8 maps higher 8-bit sub-registers to their parent register.
   410  var subHigh8 = map[string]string{
   411  	"ah": "EAX",
   412  	"ch": "ECX",
   413  	"dh": "EDX",
   414  	"bh": "EBX",
   415  }
   416  
   417  // isSubHigh8 reports whether x is a higher 8-bit sub-register.
   418  func isSubHigh8(x ast.Expr) bool {
   419  	if sub, ok := x.(*ast.Ident); ok {
   420  		_, ok = subHigh8[sub.Name]
   421  		return ok
   422  	}
   423  	return false
   424  }
   425  
   426  // sub16 maps 16-bit sub-registers to their parent register.
   427  var sub16 = map[string]string{
   428  	"ax":   "EAX",
   429  	"cx":   "ECX",
   430  	"dx":   "EDX",
   431  	"bx":   "EBX",
   432  	"sp":   "ESP",
   433  	"bp":   "EBP",
   434  	"si":   "ESI",
   435  	"di":   "EDI",
   436  	"r8w":  "R8L",
   437  	"r9w":  "R9L",
   438  	"r10w": "R10L",
   439  	"r11w": "R11L",
   440  	"r12w": "R12L",
   441  	"r13w": "R13L",
   442  	"r14w": "R14L",
   443  	"r15w": "R15L",
   444  }
   445  
   446  // isSub16 reports whether x is a 16-bit sub-register.
   447  func isSub16(x ast.Expr) bool {
   448  	if sub, ok := x.(*ast.Ident); ok {
   449  		_, ok = sub16[sub.Name]
   450  		return ok
   451  	}
   452  	return false
   453  }
   454  
   455  // extendSubReg returns the parent register of x if x is a sub-register.
   456  func extendSubReg(x ast.Expr) ast.Expr {
   457  	sub, ok := x.(*ast.Ident)
   458  	if !ok {
   459  		return x
   460  	}
   461  	// Lower 8-bit sub-registers.
   462  	if reg, ok := subLow8[sub.Name]; ok {
   463  		return getRegFromString(reg)
   464  	}
   465  	// Higher 8-bit sub-registers.
   466  	if reg, ok := subHigh8[sub.Name]; ok {
   467  		return getRegFromString(reg)
   468  	}
   469  	// 16-bit sub-registers.
   470  	if reg, ok := sub16[sub.Name]; ok {
   471  		return getRegFromString(reg)
   472  	}
   473  	return x
   474  }