github.com/cowsed/Parser@v0.0.0-20211216032244-48b10019d380/jit.go (about) 1 package parser 2 3 import ( 4 "fmt" 5 "syscall" 6 "unsafe" 7 ) 8 9 //PrintHexBytes prints out the byes in a hex format 10 func PrintHexBytes(s []byte) { 11 fmt.Print("[") 12 for i := range s { 13 fmt.Printf("0x%x ", s[i]) 14 } 15 fmt.Println("]") 16 } 17 18 //JitCompileExpression compiles the expression to jitable code 19 func JitCompileExpression(e Expression) func(vs map[string]float64) float64 { 20 mm := NewMemoryManager() 21 //Compile to intermediate bytecode 22 e.Compile(&mm) 23 //Compile to jitable code 24 f := JitCompile(&mm) 25 return f 26 } 27 28 var header = []uint8{0x48, 0x83, 0xec, 0x18, 0x48, 0x89, 0x6c, 0x24, 0x10, 0x48, 0x8d, 0x6c, 0x24, 0x10, 0x48, 0x89, 29 0x44, 0x24, 0x20, 0x48, 0x85, 0xdb, 0x76, 0x43} 30 var footer = []uint8{0x48, 0x8b, 0x6c, 0x24, 31 0x10, 0x48, 0x83, 0xc4, 0x18, 0x90, 0xc3, 0xb8, 0x02, 0x00, 0x00, 0x00, 0x48, 0x89, 0xd9, 0xe8} 32 33 //JitCompile Takes a completed memory manager of intermediate bytecode to jitable code 34 func JitCompile(mm *MemoryManager) func(vs map[string]float64) float64 { 35 //Track of which vars are needed 36 vars := make([]string, len(mm.varLocations)) 37 i := 0 38 for k := range mm.varLocations { 39 vars[i] = k 40 i++ 41 } 42 43 //Create operating memory (memory in which operations are performed. only parts can be overwritten, others must stay unchanged for the function to wo0rk muiltiple times) 44 operating := make([]float64, len(mm.constants)) 45 for i := range operating { 46 operating[i] = mm.constants[i] 47 } 48 code := []uint8{} 49 code = append(code, header...) 50 //Translate intermediate representation into x86 Assembly 51 /* 52 for i := 0; i < len(mm.bc); { 53 ins := mm.bc[i] 54 switch ins { 55 case AddBytecode: 56 Ai := mm.bc[i+1] 57 Bi := mm.bc[i+2] 58 Ri := mm.bc[i+3] 59 60 //float64s are 8 byte long 61 AiMem := uint8(Ai) * 8 62 BiMem := uint8(Bi) * 8 63 RiMem := uint8(Ri) * 8 64 65 tempCode := []uint8{ 66 //Load A from memory 67 0xf2, 0x0f, 0x10, 0x40, AiMem, 68 69 //Add B from memory to A 70 0xf2, 0x0f, 0x58, 0x40, BiMem, 71 72 //Save to R in memory 73 0xf2, 0x0f, 0x11, 0x40, RiMem, 74 } 75 code = append(code, tempCode...) 76 i += 4 77 case SubBytecode: 78 Ai := mm.bc[i+1] 79 Bi := mm.bc[i+2] 80 Ri := mm.bc[i+3] 81 82 //float64s are 8 byte long 83 AiMem := uint8(Ai) * 8 84 BiMem := uint8(Bi) * 8 85 RiMem := uint8(Ri) * 8 86 87 tempCode := []uint8{ 88 //Load A from memory 89 0xf2, 0x0f, 0x10, 0x40, AiMem, 90 91 //Sub B from memory from A 92 0xf2, 0x0f, 0x5c, 0x40, BiMem, 93 94 //Save to R in memory 95 0xf2, 0x0f, 0x11, 0x40, RiMem, 96 } 97 code = append(code, tempCode...) 98 i += 4 99 default: 100 fmt.Println("this shouldnt happen. Unrecognized function:", ins) 101 return nil 102 } 103 104 } 105 */ 106 107 code = append(code, footer...) 108 109 AsmFunction := MakeMathFunc(code) 110 if AsmFunction == nil { 111 panic(fmt.Errorf("should not happen. Nil function")) 112 } 113 fmt.Println(&AsmFunction) 114 return func(vs map[string]float64) float64 { 115 //Place variables in operating memory 116 for _, k := range vars { 117 index := mm.varLocations[k] 118 val := vs[k] 119 operating[index] = val 120 121 } 122 fmt.Println("operating mem", operating) 123 fmt.Println("executing") 124 res := AsmFunction(operating) 125 fmt.Println("result") 126 fmt.Println(res) 127 return res 128 } 129 } 130 131 //MakeMathFunc takes a bytecode and data and turns it into a function 132 func MakeMathFunc(mathFunction []uint8) func([]float64) float64 { 133 type floatFunc func([]float64) float64 134 135 fmt.Println("function length", len(mathFunction)) 136 if len(mathFunction) > 128 { 137 panic(fmt.Errorf("function too long for memory alloted")) 138 } 139 PrintHexBytes(mathFunction) 140 141 executablePrintFunc, err := syscall.Mmap( 142 -1, 143 0, 144 128, 145 syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC, 146 syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS) 147 if err != nil { 148 fmt.Printf("mmap err: %v", err) 149 } 150 151 copy(executablePrintFunc, mathFunction) ///When going back this is where it gets switched out for debug function 152 153 PrintHexBytes(executablePrintFunc) 154 155 unsafePrintFunc := (uintptr)(unsafe.Pointer(&executablePrintFunc)) 156 function := *(*floatFunc)(unsafe.Pointer(&unsafePrintFunc)) 157 return function 158 }