github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/ejit/main.go (about) 1 // 2 // WARNING 3 // 4 // DO NOT USE IN PRODUCTION 5 // IT IS VERY UNSAFE 6 // 7 // AND MOSTLY FOR FUN 8 // 9 // USE A PROPER PACKAGE FOR JIT INSTEAD OF WRITING YOUR OWN !!! 10 // 11 package main 12 13 import ( 14 "fmt" 15 "math" 16 "syscall" 17 "time" 18 "unsafe" 19 ) 20 21 type Memory [1 << 8]float64 22 23 type Var byte 24 type Model struct { 25 Vars map[string]Var 26 Eqs map[string]*Eq 27 Memory Memory 28 29 lastVar Var 30 } 31 32 func NewModel() *Model { 33 return &Model{ 34 Vars: make(map[string]Var, 512), 35 Eqs: make(map[string]*Eq, 512), 36 } 37 } 38 39 func (m *Model) AddEq(name string, expr *Expr) { 40 eq := &Eq{} 41 m.Eqs[name] = eq 42 eq.Compile(expr, m) 43 } 44 45 func (m *Model) nextVar() Var { 46 m.lastVar++ 47 return m.lastVar 48 } 49 50 func (m *Model) Run(name string) float64 { 51 eq, ok := m.Eqs[name] 52 if !ok { 53 panic("eq " + name + " not defined") 54 } 55 return eq.Run(m) 56 } 57 58 func (m *Model) SetVar(name string, value float64) Var { 59 ix, ok := m.Vars[name] 60 if !ok { 61 ix = m.nextVar() 62 m.Vars[name] = ix 63 } 64 m.Memory[ix] = value 65 return ix 66 } 67 68 func (m *Model) GetVar(name string) (float64, Var) { 69 ix, ok := m.Vars[name] 70 if !ok { 71 ix = m.nextVar() 72 m.Vars[name] = ix 73 } 74 return m.Memory[ix], ix 75 } 76 77 type Eq struct { 78 Expr *Expr 79 Code []Statement 80 Func *Func 81 Result Var 82 } 83 84 type Op byte 85 86 const ( 87 Add = Op('+') 88 Sub = Op('-') 89 Mul = Op('*') 90 Div = Op('/') 91 ) 92 93 type Statement struct { 94 // D := A op B 95 Op Op 96 D, A, B Var 97 } 98 99 // Compile assumes that *Expr represents a tree 100 func (eq *Eq) Compile(expr *Expr, model *Model) { 101 eq.Expr = expr 102 eq.Result = eq.compile(expr, model) 103 104 fn := &Func{} 105 fn.InitRAX() 106 for _, stmt := range eq.Code { 107 fn.BinaryOp(stmt.Op, stmt.D, stmt.A, stmt.B) 108 } 109 fn.Ret() 110 fn.MarkExecutable() 111 eq.Func = fn 112 } 113 114 func (eq *Eq) compile(expr *Expr, model *Model) Var { 115 switch expr.Type { 116 case ExprBinary: 117 left := eq.compile(expr.A, model) 118 right := eq.compile(expr.B, model) 119 ix := model.nextVar() 120 eq.Code = append(eq.Code, Statement{expr.Op, ix, left, right}) 121 return ix 122 case ExprConst: 123 ix := model.nextVar() 124 model.Memory[ix] = expr.Value 125 return ix 126 case ExprVar: 127 _, ix := model.GetVar(expr.VarName) 128 return ix 129 } 130 panic("invalid expr") 131 } 132 133 func (eq *Eq) RunJIT(m *Model) float64 { 134 mem := &m.Memory 135 eq.Func.Call(mem) 136 return mem[eq.Result] 137 } 138 139 func (eq *Eq) Run(m *Model) float64 { 140 mem := &m.Memory 141 for _, stmt := range eq.Code { 142 mem[stmt.D] = stmt.Op.Eval(mem[stmt.A], mem[stmt.B]) 143 } 144 return mem[eq.Result] 145 } 146 147 func (op Op) Eval(A, B float64) float64 { 148 switch op { 149 case Add: 150 return A + B 151 case Sub: 152 return A - B 153 case Mul: 154 return A * B 155 case Div: 156 return A / B 157 } 158 return math.NaN() 159 } 160 161 type ExprType byte 162 163 const ( 164 ExprInvalid = ExprType(iota) 165 ExprBinary 166 ExprConst 167 ExprVar 168 ) 169 170 // Expr represents arbitrary ast node 171 type Expr struct { 172 Type ExprType 173 VarName string 174 Value float64 175 Op Op 176 A, B *Expr 177 } 178 179 func (expr *Expr) Optimize() *Expr { 180 switch expr.Type { 181 case ExprBinary: 182 left := expr.A.Optimize() 183 right := expr.B.Optimize() 184 185 if left.Type == ExprConst && right.Type == ExprConst { 186 return &Expr{ 187 Type: ExprConst, 188 Value: expr.Op.Eval(left.Value, right.Value), 189 } 190 } 191 return &Expr{ 192 Type: ExprBinary, 193 Op: expr.Op, 194 A: left, 195 B: right, 196 } 197 case ExprConst: 198 return expr 199 case ExprVar: 200 return expr 201 } 202 panic("invalid expr") 203 } 204 205 func (expr *Expr) String() string { 206 switch expr.Type { 207 case ExprBinary: 208 return fmt.Sprintf("(%v %v %v)", expr.A, string(expr.Op), expr.B) 209 case ExprConst: 210 return fmt.Sprintf("%v", expr.Value) 211 case ExprVar: 212 return expr.VarName 213 } 214 return "???" 215 } 216 217 // Helpers to make create expressions easier 218 func EOp(left *Expr, op Op, right *Expr) *Expr { 219 return &Expr{Type: ExprBinary, Op: op, A: left, B: right} 220 } 221 func EAdd(left, right *Expr) *Expr { return EOp(left, Add, right) } 222 func ESub(left, right *Expr) *Expr { return EOp(left, Sub, right) } 223 func EMul(left, right *Expr) *Expr { return EOp(left, Mul, right) } 224 func EDiv(left, right *Expr) *Expr { return EOp(left, Div, right) } 225 func EVar(name string) *Expr { return &Expr{Type: ExprVar, VarName: name} } 226 func EConst(value float64) *Expr { return &Expr{Type: ExprConst, Value: value} } 227 228 func main() { 229 // defer profile.Start(profile.CPUProfile).Stop() 230 231 model := NewModel() 232 233 // expr := "x + 2*y/(30-2/3) + (12-2)/5 + 1 + 2 + 3 + 4 + 5" 234 // parsing left as an exercise :D 235 expr := EAdd( 236 EVar("x"), 237 EAdd( 238 EDiv( 239 EMul(EConst(2), EVar("y")), 240 ESub(EConst(30), EDiv(EConst(2), EConst(3))), 241 ), 242 EAdd( 243 EDiv(ESub(EConst(12), EConst(2)), EConst(5)), 244 EAdd(EConst(1), 245 EAdd(EConst(2), 246 EAdd(EConst(3), 247 EAdd(EConst(4), EConst(5))))), 248 ), 249 ), 250 ) 251 252 fmt.Println(expr) 253 expr = expr.Optimize() 254 fmt.Println(expr) 255 256 model.SetVar("x", 0) 257 model.SetVar("y", 0) 258 model.AddEq("testeq", expr) 259 260 n := 1000 * 1000 261 262 x := []float64{3, 4, 12, -3.4, 20} 263 y := []float64{1, 2, 3, 4, 5} 264 265 { 266 // testing walking ASTs 267 t := time.Now() 268 result := float64(0) 269 270 // cache the map lookups 271 _, ix := model.GetVar("x") 272 _, iy := model.GetVar("y") 273 eq := model.Eqs["testeq"] 274 275 for j := 0; j < n; j++ { 276 model.Memory[ix] = x[3] 277 model.Memory[iy] = y[3] 278 result += eq.RunJIT(model) 279 } 280 fmt.Println(result) 281 fmt.Println(time.Since(t)) 282 } 283 284 { 285 // testing just evaluating 286 t := time.Now() 287 result := float64(0) 288 for j := 0; j < n; j++ { 289 result += x[3] + 2*y[3]/(30-2.0/3.0) + (12-2)/5.0 + 1 + 2 + 3 + 4 + 5 290 } 291 fmt.Println(result) 292 fmt.Println(time.Since(t)) 293 } 294 } 295 296 // must only use references, otherwise when the memory moves 297 type Func struct { 298 body []byte 299 Call func(*Memory) 300 } 301 302 func (fn *Func) append(code ...byte) { fn.body = append(fn.body, code...) } 303 func (fn *Func) append32(v uint32) { 304 fn.append(byte(v>>0), byte(v>>8), byte(v>>16), byte(v>>24)) 305 } 306 307 func (fn *Func) mov_rsp8_rax() { fn.append(0x48, 0x8b, 0x44, 0x24, 0x08) } 308 func (fn *Func) movsd_rax_X0(off uint32) { 309 fn.append(0xf2, 0x0f, 0x10, 0x80) 310 fn.append32(off) 311 } 312 func (fn *Func) movsd_rax_X1(off uint32) { 313 fn.append(0xf2, 0x0f, 0x10, 0x88) 314 fn.append32(off) 315 } 316 func (fn *Func) movsd_X0_rax(off uint32) { 317 fn.append(0xf2, 0x0f, 0x11, 0x80) 318 fn.append32(off) 319 } 320 321 func (fn *Func) add_X0_X1() { fn.append(0xf2, 0x0f, 0x58, 0xc1) } 322 func (fn *Func) sub_X0_X1() { fn.append(0xf2, 0x0f, 0x5c, 0xc1) } 323 func (fn *Func) mul_X0_X1() { fn.append(0xf2, 0x0f, 0x59, 0xc1) } 324 func (fn *Func) div_X0_X1() { fn.append(0xf2, 0x0f, 0x5e, 0xc1) } 325 326 func (fn *Func) InitRAX() { fn.mov_rsp8_rax() } 327 func (fn *Func) Ret() { fn.append(0xc3) } 328 func (fn *Func) BinaryOp(op Op, dst, a, b Var) { 329 fn.movsd_rax_X0(uint32(a) * 8) 330 fn.movsd_rax_X1(uint32(b) * 8) 331 switch op { 332 case Add: 333 fn.add_X0_X1() 334 case Sub: 335 fn.sub_X0_X1() 336 case Mul: 337 fn.mul_X0_X1() 338 case Div: 339 fn.div_X0_X1() 340 } 341 fn.movsd_X0_rax(uint32(dst) * 8) 342 } 343 344 func (fn *Func) MarkExecutable() { 345 _, err := VirtualProtect(fn.body, 0x40) 346 if err != nil { 347 panic(err) 348 } 349 350 // OH GOD WHAT HAVE I DONE??? 351 type callstub struct{ fn func(*Memory) } 352 var actual struct{ body **byte } 353 pbody := &fn.body[0] 354 actual.body = &pbody 355 stub := (*callstub)(unsafe.Pointer(&actual)) 356 fn.Call = stub.fn 357 } 358 359 var ( 360 modkernel32 = syscall.NewLazyDLL("kernel32.dll") 361 procVirtualProtect = modkernel32.NewProc("VirtualProtect") 362 ) 363 364 func VirtualProtect(data []byte, newprotect uint) (oldprotect uint, err error) { 365 var op uintptr 366 r1, _, _ := procVirtualProtect.Call(uintptr(unsafe.Pointer(&data[0])), uintptr(len(data)), uintptr(newprotect), uintptr(unsafe.Pointer(&op))) 367 if r1 == 0 { 368 err = fmt.Errorf("error") 369 } 370 return uint(op), err 371 }