github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/ejit/main.go (about)

     1  //
     2  //  WARNING
     3  //
     4  //  DO NOT USE IN PRODUCTION
     5  //  IT IS VERY UNSAFE
     6  //
     7  //  AND MOSTLY FOR FUN
     8  //
     9  //  USE A PROPER PACKAGE FOR JIT INSTEAD OF WRITING YOUR OWN !!!
    10  //
    11  package main
    12  
    13  import (
    14  	"fmt"
    15  	"math"
    16  	"syscall"
    17  	"time"
    18  	"unsafe"
    19  )
    20  
    21  type Memory [1 << 8]float64
    22  
    23  type Var byte
    24  type Model struct {
    25  	Vars   map[string]Var
    26  	Eqs    map[string]*Eq
    27  	Memory Memory
    28  
    29  	lastVar Var
    30  }
    31  
    32  func NewModel() *Model {
    33  	return &Model{
    34  		Vars: make(map[string]Var, 512),
    35  		Eqs:  make(map[string]*Eq, 512),
    36  	}
    37  }
    38  
    39  func (m *Model) AddEq(name string, expr *Expr) {
    40  	eq := &Eq{}
    41  	m.Eqs[name] = eq
    42  	eq.Compile(expr, m)
    43  }
    44  
    45  func (m *Model) nextVar() Var {
    46  	m.lastVar++
    47  	return m.lastVar
    48  }
    49  
    50  func (m *Model) Run(name string) float64 {
    51  	eq, ok := m.Eqs[name]
    52  	if !ok {
    53  		panic("eq " + name + " not defined")
    54  	}
    55  	return eq.Run(m)
    56  }
    57  
    58  func (m *Model) SetVar(name string, value float64) Var {
    59  	ix, ok := m.Vars[name]
    60  	if !ok {
    61  		ix = m.nextVar()
    62  		m.Vars[name] = ix
    63  	}
    64  	m.Memory[ix] = value
    65  	return ix
    66  }
    67  
    68  func (m *Model) GetVar(name string) (float64, Var) {
    69  	ix, ok := m.Vars[name]
    70  	if !ok {
    71  		ix = m.nextVar()
    72  		m.Vars[name] = ix
    73  	}
    74  	return m.Memory[ix], ix
    75  }
    76  
    77  type Eq struct {
    78  	Expr   *Expr
    79  	Code   []Statement
    80  	Func   *Func
    81  	Result Var
    82  }
    83  
    84  type Op byte
    85  
    86  const (
    87  	Add = Op('+')
    88  	Sub = Op('-')
    89  	Mul = Op('*')
    90  	Div = Op('/')
    91  )
    92  
    93  type Statement struct {
    94  	// D := A op B
    95  	Op      Op
    96  	D, A, B Var
    97  }
    98  
    99  // Compile assumes that *Expr represents a tree
   100  func (eq *Eq) Compile(expr *Expr, model *Model) {
   101  	eq.Expr = expr
   102  	eq.Result = eq.compile(expr, model)
   103  
   104  	fn := &Func{}
   105  	fn.InitRAX()
   106  	for _, stmt := range eq.Code {
   107  		fn.BinaryOp(stmt.Op, stmt.D, stmt.A, stmt.B)
   108  	}
   109  	fn.Ret()
   110  	fn.MarkExecutable()
   111  	eq.Func = fn
   112  }
   113  
   114  func (eq *Eq) compile(expr *Expr, model *Model) Var {
   115  	switch expr.Type {
   116  	case ExprBinary:
   117  		left := eq.compile(expr.A, model)
   118  		right := eq.compile(expr.B, model)
   119  		ix := model.nextVar()
   120  		eq.Code = append(eq.Code, Statement{expr.Op, ix, left, right})
   121  		return ix
   122  	case ExprConst:
   123  		ix := model.nextVar()
   124  		model.Memory[ix] = expr.Value
   125  		return ix
   126  	case ExprVar:
   127  		_, ix := model.GetVar(expr.VarName)
   128  		return ix
   129  	}
   130  	panic("invalid expr")
   131  }
   132  
   133  func (eq *Eq) RunJIT(m *Model) float64 {
   134  	mem := &m.Memory
   135  	eq.Func.Call(mem)
   136  	return mem[eq.Result]
   137  }
   138  
   139  func (eq *Eq) Run(m *Model) float64 {
   140  	mem := &m.Memory
   141  	for _, stmt := range eq.Code {
   142  		mem[stmt.D] = stmt.Op.Eval(mem[stmt.A], mem[stmt.B])
   143  	}
   144  	return mem[eq.Result]
   145  }
   146  
   147  func (op Op) Eval(A, B float64) float64 {
   148  	switch op {
   149  	case Add:
   150  		return A + B
   151  	case Sub:
   152  		return A - B
   153  	case Mul:
   154  		return A * B
   155  	case Div:
   156  		return A / B
   157  	}
   158  	return math.NaN()
   159  }
   160  
   161  type ExprType byte
   162  
   163  const (
   164  	ExprInvalid = ExprType(iota)
   165  	ExprBinary
   166  	ExprConst
   167  	ExprVar
   168  )
   169  
   170  // Expr represents arbitrary ast node
   171  type Expr struct {
   172  	Type    ExprType
   173  	VarName string
   174  	Value   float64
   175  	Op      Op
   176  	A, B    *Expr
   177  }
   178  
   179  func (expr *Expr) Optimize() *Expr {
   180  	switch expr.Type {
   181  	case ExprBinary:
   182  		left := expr.A.Optimize()
   183  		right := expr.B.Optimize()
   184  
   185  		if left.Type == ExprConst && right.Type == ExprConst {
   186  			return &Expr{
   187  				Type:  ExprConst,
   188  				Value: expr.Op.Eval(left.Value, right.Value),
   189  			}
   190  		}
   191  		return &Expr{
   192  			Type: ExprBinary,
   193  			Op:   expr.Op,
   194  			A:    left,
   195  			B:    right,
   196  		}
   197  	case ExprConst:
   198  		return expr
   199  	case ExprVar:
   200  		return expr
   201  	}
   202  	panic("invalid expr")
   203  }
   204  
   205  func (expr *Expr) String() string {
   206  	switch expr.Type {
   207  	case ExprBinary:
   208  		return fmt.Sprintf("(%v %v %v)", expr.A, string(expr.Op), expr.B)
   209  	case ExprConst:
   210  		return fmt.Sprintf("%v", expr.Value)
   211  	case ExprVar:
   212  		return expr.VarName
   213  	}
   214  	return "???"
   215  }
   216  
   217  // Helpers to make create expressions easier
   218  func EOp(left *Expr, op Op, right *Expr) *Expr {
   219  	return &Expr{Type: ExprBinary, Op: op, A: left, B: right}
   220  }
   221  func EAdd(left, right *Expr) *Expr { return EOp(left, Add, right) }
   222  func ESub(left, right *Expr) *Expr { return EOp(left, Sub, right) }
   223  func EMul(left, right *Expr) *Expr { return EOp(left, Mul, right) }
   224  func EDiv(left, right *Expr) *Expr { return EOp(left, Div, right) }
   225  func EVar(name string) *Expr       { return &Expr{Type: ExprVar, VarName: name} }
   226  func EConst(value float64) *Expr   { return &Expr{Type: ExprConst, Value: value} }
   227  
   228  func main() {
   229  	// defer profile.Start(profile.CPUProfile).Stop()
   230  
   231  	model := NewModel()
   232  
   233  	// expr := "x + 2*y/(30-2/3) + (12-2)/5 + 1 + 2 + 3 + 4 + 5"
   234  	// parsing left as an exercise :D
   235  	expr := EAdd(
   236  		EVar("x"),
   237  		EAdd(
   238  			EDiv(
   239  				EMul(EConst(2), EVar("y")),
   240  				ESub(EConst(30), EDiv(EConst(2), EConst(3))),
   241  			),
   242  			EAdd(
   243  				EDiv(ESub(EConst(12), EConst(2)), EConst(5)),
   244  				EAdd(EConst(1),
   245  					EAdd(EConst(2),
   246  						EAdd(EConst(3),
   247  							EAdd(EConst(4), EConst(5))))),
   248  			),
   249  		),
   250  	)
   251  
   252  	fmt.Println(expr)
   253  	expr = expr.Optimize()
   254  	fmt.Println(expr)
   255  
   256  	model.SetVar("x", 0)
   257  	model.SetVar("y", 0)
   258  	model.AddEq("testeq", expr)
   259  
   260  	n := 1000 * 1000
   261  
   262  	x := []float64{3, 4, 12, -3.4, 20}
   263  	y := []float64{1, 2, 3, 4, 5}
   264  
   265  	{
   266  		// testing walking ASTs
   267  		t := time.Now()
   268  		result := float64(0)
   269  
   270  		// cache the map lookups
   271  		_, ix := model.GetVar("x")
   272  		_, iy := model.GetVar("y")
   273  		eq := model.Eqs["testeq"]
   274  
   275  		for j := 0; j < n; j++ {
   276  			model.Memory[ix] = x[3]
   277  			model.Memory[iy] = y[3]
   278  			result += eq.RunJIT(model)
   279  		}
   280  		fmt.Println(result)
   281  		fmt.Println(time.Since(t))
   282  	}
   283  
   284  	{
   285  		// testing just evaluating
   286  		t := time.Now()
   287  		result := float64(0)
   288  		for j := 0; j < n; j++ {
   289  			result += x[3] + 2*y[3]/(30-2.0/3.0) + (12-2)/5.0 + 1 + 2 + 3 + 4 + 5
   290  		}
   291  		fmt.Println(result)
   292  		fmt.Println(time.Since(t))
   293  	}
   294  }
   295  
   296  // must only use references, otherwise when the memory moves
   297  type Func struct {
   298  	body []byte
   299  	Call func(*Memory)
   300  }
   301  
   302  func (fn *Func) append(code ...byte) { fn.body = append(fn.body, code...) }
   303  func (fn *Func) append32(v uint32) {
   304  	fn.append(byte(v>>0), byte(v>>8), byte(v>>16), byte(v>>24))
   305  }
   306  
   307  func (fn *Func) mov_rsp8_rax() { fn.append(0x48, 0x8b, 0x44, 0x24, 0x08) }
   308  func (fn *Func) movsd_rax_X0(off uint32) {
   309  	fn.append(0xf2, 0x0f, 0x10, 0x80)
   310  	fn.append32(off)
   311  }
   312  func (fn *Func) movsd_rax_X1(off uint32) {
   313  	fn.append(0xf2, 0x0f, 0x10, 0x88)
   314  	fn.append32(off)
   315  }
   316  func (fn *Func) movsd_X0_rax(off uint32) {
   317  	fn.append(0xf2, 0x0f, 0x11, 0x80)
   318  	fn.append32(off)
   319  }
   320  
   321  func (fn *Func) add_X0_X1() { fn.append(0xf2, 0x0f, 0x58, 0xc1) }
   322  func (fn *Func) sub_X0_X1() { fn.append(0xf2, 0x0f, 0x5c, 0xc1) }
   323  func (fn *Func) mul_X0_X1() { fn.append(0xf2, 0x0f, 0x59, 0xc1) }
   324  func (fn *Func) div_X0_X1() { fn.append(0xf2, 0x0f, 0x5e, 0xc1) }
   325  
   326  func (fn *Func) InitRAX() { fn.mov_rsp8_rax() }
   327  func (fn *Func) Ret()     { fn.append(0xc3) }
   328  func (fn *Func) BinaryOp(op Op, dst, a, b Var) {
   329  	fn.movsd_rax_X0(uint32(a) * 8)
   330  	fn.movsd_rax_X1(uint32(b) * 8)
   331  	switch op {
   332  	case Add:
   333  		fn.add_X0_X1()
   334  	case Sub:
   335  		fn.sub_X0_X1()
   336  	case Mul:
   337  		fn.mul_X0_X1()
   338  	case Div:
   339  		fn.div_X0_X1()
   340  	}
   341  	fn.movsd_X0_rax(uint32(dst) * 8)
   342  }
   343  
   344  func (fn *Func) MarkExecutable() {
   345  	_, err := VirtualProtect(fn.body, 0x40)
   346  	if err != nil {
   347  		panic(err)
   348  	}
   349  
   350  	// OH GOD WHAT HAVE I DONE???
   351  	type callstub struct{ fn func(*Memory) }
   352  	var actual struct{ body **byte }
   353  	pbody := &fn.body[0]
   354  	actual.body = &pbody
   355  	stub := (*callstub)(unsafe.Pointer(&actual))
   356  	fn.Call = stub.fn
   357  }
   358  
   359  var (
   360  	modkernel32        = syscall.NewLazyDLL("kernel32.dll")
   361  	procVirtualProtect = modkernel32.NewProc("VirtualProtect")
   362  )
   363  
   364  func VirtualProtect(data []byte, newprotect uint) (oldprotect uint, err error) {
   365  	var op uintptr
   366  	r1, _, _ := procVirtualProtect.Call(uintptr(unsafe.Pointer(&data[0])), uintptr(len(data)), uintptr(newprotect), uintptr(unsafe.Pointer(&op)))
   367  	if r1 == 0 {
   368  		err = fmt.Errorf("error")
   369  	}
   370  	return uint(op), err
   371  }