github.com/cowsed/Parser@v0.0.0-20211216032244-48b10019d380/jit.go (about)

     1  package parser
     2  
     3  import (
     4  	"fmt"
     5  	"syscall"
     6  	"unsafe"
     7  )
     8  
     9  //PrintHexBytes prints out the byes in a hex format
    10  func PrintHexBytes(s []byte) {
    11  	fmt.Print("[")
    12  	for i := range s {
    13  		fmt.Printf("0x%x ", s[i])
    14  	}
    15  	fmt.Println("]")
    16  }
    17  
    18  //JitCompileExpression compiles the expression to jitable code
    19  func JitCompileExpression(e Expression) func(vs map[string]float64) float64 {
    20  	mm := NewMemoryManager()
    21  	//Compile to intermediate bytecode
    22  	e.Compile(&mm)
    23  	//Compile to jitable code
    24  	f := JitCompile(&mm)
    25  	return f
    26  }
    27  
    28  var header = []uint8{0x48, 0x83, 0xec, 0x18, 0x48, 0x89, 0x6c, 0x24, 0x10, 0x48, 0x8d, 0x6c, 0x24, 0x10, 0x48, 0x89,
    29  	0x44, 0x24, 0x20, 0x48, 0x85, 0xdb, 0x76, 0x43}
    30  var footer = []uint8{0x48, 0x8b, 0x6c, 0x24,
    31  	0x10, 0x48, 0x83, 0xc4, 0x18, 0x90, 0xc3, 0xb8, 0x02, 0x00, 0x00, 0x00, 0x48, 0x89, 0xd9, 0xe8}
    32  
    33  //JitCompile Takes a completed memory manager of intermediate bytecode to jitable code
    34  func JitCompile(mm *MemoryManager) func(vs map[string]float64) float64 {
    35  	//Track of which vars are needed
    36  	vars := make([]string, len(mm.varLocations))
    37  	i := 0
    38  	for k := range mm.varLocations {
    39  		vars[i] = k
    40  		i++
    41  	}
    42  
    43  	//Create operating memory (memory in which operations are performed. only parts can be overwritten, others must stay unchanged for the function to wo0rk muiltiple times)
    44  	operating := make([]float64, len(mm.constants))
    45  	for i := range operating {
    46  		operating[i] = mm.constants[i]
    47  	}
    48  	code := []uint8{}
    49  	code = append(code, header...)
    50  	//Translate intermediate representation into x86 Assembly
    51  	/*
    52  		for i := 0; i < len(mm.bc); {
    53  			ins := mm.bc[i]
    54  			switch ins {
    55  			case AddBytecode:
    56  				Ai := mm.bc[i+1]
    57  				Bi := mm.bc[i+2]
    58  				Ri := mm.bc[i+3]
    59  
    60  				//float64s are 8 byte long
    61  				AiMem := uint8(Ai) * 8
    62  				BiMem := uint8(Bi) * 8
    63  				RiMem := uint8(Ri) * 8
    64  
    65  				tempCode := []uint8{
    66  					//Load A from memory
    67  					0xf2, 0x0f, 0x10, 0x40, AiMem,
    68  
    69  					//Add B from memory to A
    70  					0xf2, 0x0f, 0x58, 0x40, BiMem,
    71  
    72  					//Save to R in memory
    73  					0xf2, 0x0f, 0x11, 0x40, RiMem,
    74  				}
    75  				code = append(code, tempCode...)
    76  				i += 4
    77  			case SubBytecode:
    78  				Ai := mm.bc[i+1]
    79  				Bi := mm.bc[i+2]
    80  				Ri := mm.bc[i+3]
    81  
    82  				//float64s are 8 byte long
    83  				AiMem := uint8(Ai) * 8
    84  				BiMem := uint8(Bi) * 8
    85  				RiMem := uint8(Ri) * 8
    86  
    87  				tempCode := []uint8{
    88  					//Load A from memory
    89  					0xf2, 0x0f, 0x10, 0x40, AiMem,
    90  
    91  					//Sub B from memory from A
    92  					0xf2, 0x0f, 0x5c, 0x40, BiMem,
    93  
    94  					//Save to R in memory
    95  					0xf2, 0x0f, 0x11, 0x40, RiMem,
    96  				}
    97  				code = append(code, tempCode...)
    98  				i += 4
    99  			default:
   100  				fmt.Println("this shouldnt happen. Unrecognized function:", ins)
   101  				return nil
   102  			}
   103  
   104  		}
   105  	*/
   106  
   107  	code = append(code, footer...)
   108  
   109  	AsmFunction := MakeMathFunc(code)
   110  	if AsmFunction == nil {
   111  		panic(fmt.Errorf("should not happen. Nil function"))
   112  	}
   113  	fmt.Println(&AsmFunction)
   114  	return func(vs map[string]float64) float64 {
   115  		//Place variables in operating memory
   116  		for _, k := range vars {
   117  			index := mm.varLocations[k]
   118  			val := vs[k]
   119  			operating[index] = val
   120  
   121  		}
   122  		fmt.Println("operating mem", operating)
   123  		fmt.Println("executing")
   124  		res := AsmFunction(operating)
   125  		fmt.Println("result")
   126  		fmt.Println(res)
   127  		return res
   128  	}
   129  }
   130  
   131  //MakeMathFunc takes a bytecode and data and turns it into a function
   132  func MakeMathFunc(mathFunction []uint8) func([]float64) float64 {
   133  	type floatFunc func([]float64) float64
   134  
   135  	fmt.Println("function length", len(mathFunction))
   136  	if len(mathFunction) > 128 {
   137  		panic(fmt.Errorf("function too long for memory alloted"))
   138  	}
   139  	PrintHexBytes(mathFunction)
   140  
   141  	executablePrintFunc, err := syscall.Mmap(
   142  		-1,
   143  		0,
   144  		128,
   145  		syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC,
   146  		syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS)
   147  	if err != nil {
   148  		fmt.Printf("mmap err: %v", err)
   149  	}
   150  
   151  	copy(executablePrintFunc, mathFunction) ///When going back this is where it gets switched out for debug function
   152  
   153  	PrintHexBytes(executablePrintFunc)
   154  
   155  	unsafePrintFunc := (uintptr)(unsafe.Pointer(&executablePrintFunc))
   156  	function := *(*floatFunc)(unsafe.Pointer(&unsafePrintFunc))
   157  	return function
   158  }