github.com/decomp/exp@v0.0.0-20210624183419-6d058f5e1da6/lift/x86/terminator.go

package x86

import (
	"fmt"
	"sort"

	"github.com/decomp/exp/bin"
	"github.com/decomp/exp/disasm/x86"
	"github.com/kr/pretty"
	"github.com/llir/llvm/ir"
	"github.com/llir/llvm/ir/constant"
	"github.com/llir/llvm/ir/types"
	"github.com/pkg/errors"
	"golang.org/x/arch/x86/x86asm"
)

// liftTerm lifts the given x86 terminator to LLVM IR, emitting code to f.
func (f *Func) liftTerm(term *x86.Inst) error {
	// Handle implicit fallthrough terminators.
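	// A dummy terminator marks a basic block which falls straight through into
	// its successor, and lifts to an unconditional branch. Illustrative only
	// (addresses made up):
	//
	//    ; basic block at 0x401000 falling through into 0x401005
	//
	// lifts to
	//
	//    br label %block_401005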
	if term.IsDummyTerm() {
		dbg.Printf("lifting implicit terminator: JMP %v", term.Addr)
		next, ok := f.blocks[term.Addr]
		if !ok {
			return errors.Errorf("unable to locate basic block at %v", term.Addr)
		}
		f.cur.NewBr(next)
		return nil
	}

	dbg.Println("lifting terminator:", term.Inst)

	// Check if any instruction prefixes are present.
	for _, prefix := range term.Prefix[:] {
		// The first zero in the array marks the end of the prefixes.
		if prefix == 0 {
			break
		}
		switch prefix {
		case x86asm.PrefixData16, x86asm.PrefixData16 | x86asm.PrefixImplicit:
			// Prefix already supported.
		default:
			pretty.Println("terminator with prefix:", term)
			panic(fmt.Errorf("support for %v terminator with prefix not yet implemented", term.Op))
		}
	}

	// Translate terminator.
	switch term.Op {
	// Loop terminators.
	case x86asm.LOOP:
		return f.liftTermLOOP(term)
	case x86asm.LOOPE:
		return f.liftTermLOOPE(term)
	case x86asm.LOOPNE:
		return f.liftTermLOOPNE(term)
	// Conditional jump terminators.
	case x86asm.JA:
		return f.liftTermJA(term)
	case x86asm.JAE:
		return f.liftTermJAE(term)
	case x86asm.JB:
		return f.liftTermJB(term)
	case x86asm.JBE:
		return f.liftTermJBE(term)
	case x86asm.JCXZ:
		return f.liftTermJCXZ(term)
	case x86asm.JE:
		return f.liftTermJE(term)
	case x86asm.JECXZ:
		return f.liftTermJECXZ(term)
	case x86asm.JG:
		return f.liftTermJG(term)
	case x86asm.JGE:
		return f.liftTermJGE(term)
	case x86asm.JL:
		return f.liftTermJL(term)
	case x86asm.JLE:
		return f.liftTermJLE(term)
	case x86asm.JNE:
		return f.liftTermJNE(term)
	case x86asm.JNO:
		return f.liftTermJNO(term)
	case x86asm.JNP:
		return f.liftTermJNP(term)
	case x86asm.JNS:
		return f.liftTermJNS(term)
	case x86asm.JO:
		return f.liftTermJO(term)
	case x86asm.JP:
		return f.liftTermJP(term)
	case x86asm.JRCXZ:
		return f.liftTermJRCXZ(term)
	case x86asm.JS:
		return f.liftTermJS(term)
	// Unconditional jump terminators.
	case x86asm.JMP:
		return f.liftTermJMP(term)
	// Return terminators.
	case x86asm.RET:
		return f.liftTermRET(term)
	default:
		panic(fmt.Errorf("support for x86 terminator opcode %v not yet implemented", term.Op))
	}
}
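
// Note on the conditional jump lifters dispatched above: each follows the same
// basic shape, namely evaluate the status flag condition of the jump, then
// emit a two-way conditional branch to the jump target and the fallthrough
// block. A minimal sketch (cond, jumpAddr and fallAddr are illustrative names,
// not the exact helpers used by the per-opcode lifters):
//
//    cond := ...                        // e.g. ZF=1 for JE
//    targetTrue := f.blocks[jumpAddr]   // branch taken
//    targetFalse := f.blocks[fallAddr]  // branch not taken
//    f.cur.NewCondBr(cond, targetTrue, targetFalse)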

// --- [ JMP ] -----------------------------------------------------------------

// liftTermJMP lifts the given x86 JMP terminator to LLVM IR, emitting code to
// f.
func (f *Func) liftTermJMP(term *x86.Inst) error {
	// Handle tail calls.
	if f.isTailCall(term) {
		// Hack: interpret the JMP instruction as a CALL instruction. This works
		// since liftInstCALL only interprets inst.Args[0], which is the same in
		// both JMP and CALL instructions.
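		// For example, a terminating
		//
		//    JMP _puts
		//
		// where _puts lies outside of the current function lifts to a call
		// followed by a return (illustrative; the exact IR depends on the
		// recovered function signature):
		//
		//    %result = call i32 @_puts(i8* %s)
		//    ret i32 %result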
		if err := f.liftInstCALL(term); err != nil {
			return errors.WithStack(err)
		}
		// Handle return values.
		if !types.Equal(f.Sig.RetType, types.Void) {
			// Non-void function; pass the return value in EAX.
			result := f.useReg(x86.EAX)
			f.cur.NewRet(result)
			return nil
		}
		f.cur.NewRet(nil)
		return nil
	}

	// Handle static jumps.
	arg := term.Arg(0)
	if targetAddr, ok := f.getAddr(arg); ok {
		target, ok := f.blocks[targetAddr]
		if !ok {
			return errors.Errorf("unable to locate target basic block at %v", targetAddr)
		}
		f.cur.NewBr(target)
		return nil
	}
	// Handle jump tables.
	if _, ok := arg.Arg.(x86asm.Mem); ok {
		mem := term.Mem(0)
		if targetAddrs, ok := f.l.Tables[bin.Address(mem.Disp)]; ok {
			// TODO: Implement proper support for jump table translation. The
			// current implementation makes a range of assumptions which do not
			// hold true in the general case; e.g. that mem.Base == 0 and
			// mem.Scale == 4.
			if mem.Mem.Base != 0 {
				panic("support for jump table memory reference with base register not yet implemented")
			}
			if mem.Scale != 4 {
				panic(fmt.Errorf("support for jump table memory reference with scale %d not yet implemented", mem.Scale))
			}

			// TODO: Locate the default target using information from symbolic
			// execution and predecessor basic blocks.

			// At this stage of recovery, the assumption is that `index` is
			// always within the range of the jump table offsets. Thus, the
			// default branch is always unreachable.
			//
			// This assumption will be validated and revisited when information
			// from symbolic execution is available.

			// TODO: Add support for indirect jump tables; i.e.
			//
			//    targets[values[index]]
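			// The recovered jump table thus lowers to a switch on the index
			// register, with one case per table entry; e.g. for a three-entry
			// table (illustrative addresses):
			//
			//    switch i32 %index, label %unreachable [
			//       i32 0, label %block_401000
			//       i32 1, label %block_401010
			//       i32 2, label %block_401020
			//    ]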
			index := f.useReg(mem.Index())
			unreachable := &ir.Block{}
			unreachable.NewUnreachable()
			f.Blocks = append(f.Blocks, unreachable)
			targetDefault := unreachable
			var cases []*ir.Case
			for i, targetAddr := range targetAddrs {
				target, ok := f.blocks[targetAddr]
				if !ok {
					return errors.Errorf("unable to locate basic block at %v", targetAddr)
				}
				ii := constant.NewInt(index.Type().(*types.IntType), int64(i))
				c := ir.NewCase(ii, target)
				cases = append(cases, c)
			}
			f.cur.NewSwitch(index, targetDefault, cases...)
			return nil
		}
	}
	pretty.Println("term:", term)
	panic("liftTermJMP: not yet implemented")
}

// --- [ RET ] -----------------------------------------------------------------

// liftTermRET lifts the given x86 RET terminator to LLVM IR, emitting code to
// f.
func (f *Func) liftTermRET(term *x86.Inst) error {
	// Handle return values of non-void functions (passed through EAX).
	if !types.Equal(f.Sig.RetType, types.Void) {
		result := f.useReg(x86.EAX)
		f.cur.NewRet(result)
		return nil
	}
	f.cur.NewRet(nil)
	return nil
}

// === [ Helper functions ] ====================================================

// isTailCall reports whether the given JMP instruction is a tail call.
func (f *Func) isTailCall(inst *x86.Inst) bool {
	arg := inst.Arg(0)
	if target, ok := f.getAddr(arg); ok {
		if f.contains(target) {
			return false
		}
		if !f.l.IsFunc(target) {
			dbg.Println("arg:", arg)
			pretty.Println(arg)
			panic(fmt.Errorf("tail call to non-function address %v", target))
		}
		return true
	}
	// Target read from jump table (e.g. switch statement).
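	// An illustrative example (displacement made up):
	//
	//    JMP [0x413A60+ECX*4]
	//
	// where 0x413A60 is the displacement of a recovered jump table whose
	// entries may reference targets outside of the current function.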
	if mem, ok := arg.Arg.(x86asm.Mem); ok {
		addr := bin.Address(mem.Disp)
		if targets, ok := f.l.Tables[addr]; ok {
			for _, target := range targets {
				if !f.contains(target) {
					if !f.l.IsFunc(target) {
						dbg.Println("arg:", arg)
						pretty.Println(arg)
						panic(fmt.Errorf("tail call to non-function address %v", target))
					}
					return true
				}
			}
			return false
		}
	}

	// TODO: Find a prettier solution for handling indirect jumps to potential
	// tail call functions at register-relative memory locations; e.g.
	//
	//    JMP [EAX+0x8]

	// HACK: set the current basic block to a dummy basic block so that we may
	// invoke f.getFunc (which may emit load instructions) to figure out if we
	// are jumping to a function.
	cur := f.cur
	dummy := &ir.Block{}
	f.cur = dummy
	_, _, _, ok := f.getFunc(arg)
	f.cur = cur
	if ok {
		return true
	}

	dbg.Println("arg:", arg)
	pretty.Println(arg)
	panic("isTailCall: not yet implemented")
}

// contains reports whether the target address is part of the address space of
// the function.
func (f *Func) contains(target bin.Address) bool {
	// Target inside function address range.
	entry := f.AsmFunc.Addr
	funcEnd := f.l.getFuncEndAddr(entry)
	if entry <= target && target < funcEnd {
		return true
	}
	// Target inside function chunk.
	if parents, ok := f.l.Chunks[target]; ok {
		if parents[entry] {
			return true
		}
	}
	// Target is an imported function, and thus not contained within the
	// function.
	if _, ok := f.l.File.Imports[target]; ok {
		return false
	}
	return false
}

// getFuncEndAddr returns the end address of the function starting at entry,
// approximated by the entry point of the next function, or by the end of the
// code section if no later function exists.
func (l *Lifter) getFuncEndAddr(entry bin.Address) bin.Address {
	less := func(i int) bool {
		return entry < l.FuncAddrs[i]
	}
	index := sort.Search(len(l.FuncAddrs), less)
	if index < len(l.FuncAddrs) {
		return l.FuncAddrs[index]
	}
	return l.getCodeEnd()
}
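
// For example (illustrative, assuming l.FuncAddrs is sorted in ascending
// order): with FuncAddrs = [0x401000, 0x401050, 0x4010A0], the end address of
// the function at entry 0x401050 is 0x4010A0, the entry point of the next
// function. For the final function in the list, the end of the code section is
// used instead.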

//// getCodeStart returns the start address of the code section.
//func (l *Lifter) getCodeStart() bin.Address {
//	return bin.Address(l.imageBase + l.codeBase)
//}

// getCodeEnd returns the end address of the code section, computed as the
// maximum end address across all executable sections.
func (l *Lifter) getCodeEnd() bin.Address {
	var max bin.Address
	for _, sect := range l.File.Sections {
		if sect.Perm&bin.PermX != 0 {
			end := sect.Addr + bin.Address(len(sect.Data))
			if max < end {
				max = end
			}
		}
	}
	if max == 0 {
		panic("unable to locate end of code segment")
	}
	return max
}