github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_test.go (about) 1 package compiler 2 3 import ( 4 "fmt" 5 "math" 6 "os" 7 "runtime" 8 "testing" 9 "unsafe" 10 11 "github.com/bananabytelabs/wazero/internal/platform" 12 "github.com/bananabytelabs/wazero/internal/testing/require" 13 "github.com/bananabytelabs/wazero/internal/wasm" 14 "github.com/bananabytelabs/wazero/internal/wazeroir" 15 ) 16 17 func TestMain(m *testing.M) { 18 if !platform.CompilerSupported() { 19 os.Exit(0) 20 } 21 os.Exit(m.Run()) 22 } 23 24 // Ensures that the offset consts do not drift when we manipulate the target 25 // structs. 26 // 27 // Note: This is a package initializer as many tests could fail if these 28 // constants are misaligned, hiding the root cause. 29 func init() { 30 var me moduleEngine 31 requireEqual := func(expected, actual int, name string) { 32 if expected != actual { 33 panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual)) 34 } 35 } 36 requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset") 37 38 var ce callEngine 39 // Offsets for callEngine.moduleContext. 40 requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset") 41 requireEqual(int(unsafe.Offsetof(ce.moduleInstance)), callEngineModuleContextModuleInstanceOffset, "callEngineModuleContextModuleInstanceOffset") 42 requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset") 43 requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset") 44 requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset") 45 requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset") 46 requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset") 47 requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset") 48 requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset") 49 requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset") 50 requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset") 51 52 // Offsets for callEngine.stackContext 53 requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset") 54 requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset") 55 requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset") 56 requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset") 57 58 // Offsets for callEngine.exitContext. 59 requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset") 60 requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset") 61 requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset") 62 requireEqual(int(unsafe.Offsetof(ce.callerModuleInstance)), callEngineExitContextCallerModuleInstanceOffset, "callEngineExitContextCallerModuleInstanceOffset") 63 64 // Size and offsets for callFrame. 65 var frame callFrame 66 requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize") 67 68 // Offsets for code. 69 var f function 70 requireEqual(int(unsafe.Offsetof(f.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset") 71 requireEqual(int(unsafe.Offsetof(f.moduleInstance)), functionModuleInstanceOffset, "functionModuleInstanceOffset") 72 requireEqual(int(unsafe.Offsetof(f.typeID)), functionTypeIDOffset, "functionTypeIDOffset") 73 requireEqual(int(unsafe.Sizeof(f)), functionSize, "functionModuleInstanceOffset") 74 75 // Offsets for wasm.ModuleInstance. 76 var moduleInstance wasm.ModuleInstance 77 requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset") 78 requireEqual(int(unsafe.Offsetof(moduleInstance.MemoryInstance)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset") 79 requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset") 80 requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset") 81 requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset") 82 requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset") 83 requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset") 84 85 // Offsets for wasm.Table. 86 var tableInstance wasm.TableInstance 87 requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset") 88 // We add "+8" to get the length of Tables[0].Table 89 // since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 90 requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset") 91 92 // Offsets for wasm.Memory 93 var memoryInstance wasm.MemoryInstance 94 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset") 95 // "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 96 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset") 97 98 // Offsets for wasm.GlobalInstance 99 var globalInstance wasm.GlobalInstance 100 requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset") 101 102 var dataInstance wasm.DataInstance 103 requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize") 104 105 var elementInstance wasm.ElementInstance 106 requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize") 107 108 var pointer uintptr 109 requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2") 110 } 111 112 type compilerEnv struct { 113 me *moduleEngine 114 ce *callEngine 115 moduleInstance *wasm.ModuleInstance 116 } 117 118 func (j *compilerEnv) stackTopAsUint32() uint32 { 119 return uint32(j.stack()[j.ce.stackContext.stackPointer-1]) 120 } 121 122 func (j *compilerEnv) stackTopAsInt32() int32 { 123 return int32(j.stack()[j.ce.stackContext.stackPointer-1]) 124 } 125 126 func (j *compilerEnv) stackTopAsUint64() uint64 { 127 return j.stack()[j.ce.stackContext.stackPointer-1] 128 } 129 130 func (j *compilerEnv) stackTopAsInt64() int64 { 131 return int64(j.stack()[j.ce.stackContext.stackPointer-1]) 132 } 133 134 func (j *compilerEnv) stackTopAsFloat32() float32 { 135 return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1])) 136 } 137 138 func (j *compilerEnv) stackTopAsFloat64() float64 { 139 return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1]) 140 } 141 142 func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) { 143 st := j.stack() 144 return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1] 145 } 146 147 func (j *compilerEnv) memory() []byte { 148 return j.moduleInstance.MemoryInstance.Buffer 149 } 150 151 func (j *compilerEnv) stack() []uint64 { 152 return j.ce.stack 153 } 154 155 func (j *compilerEnv) compilerStatus() nativeCallStatusCode { 156 return j.ce.exitContext.statusCode 157 } 158 159 func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index { 160 return j.ce.exitContext.builtinFunctionCallIndex 161 } 162 163 // stackPointer returns the stack pointer minus the call frame. 164 func (j *compilerEnv) stackPointer() uint64 { 165 return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64 166 } 167 168 func (j *compilerEnv) stackBasePointer() uint64 { 169 return j.ce.stackContext.stackBasePointerInBytes >> 3 170 } 171 172 func (j *compilerEnv) setStackPointer(sp uint64) { 173 j.ce.stackContext.stackPointer = sp 174 } 175 176 func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) { 177 j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...) 178 } 179 180 func (j *compilerEnv) globals() []*wasm.GlobalInstance { 181 return j.moduleInstance.Globals 182 } 183 184 func (j *compilerEnv) addTable(table *wasm.TableInstance) { 185 j.moduleInstance.Tables = append(j.moduleInstance.Tables, table) 186 } 187 188 func (j *compilerEnv) setStackBasePointer(sp uint64) { 189 j.ce.stackContext.stackBasePointerInBytes = sp << 3 190 } 191 192 func (j *compilerEnv) module() *wasm.ModuleInstance { 193 return j.moduleInstance 194 } 195 196 func (j *compilerEnv) moduleEngine() *moduleEngine { 197 return j.me 198 } 199 200 func (j *compilerEnv) callEngine() *callEngine { 201 return j.ce 202 } 203 204 func (j *compilerEnv) exec(machineCode []byte) { 205 cm := &compiledModule{compiledCode: &compiledCode{}} 206 if err := cm.executable.Map(len(machineCode)); err != nil { 207 panic(err) 208 } 209 executable := cm.executable.Bytes() 210 copy(executable, machineCode) 211 makeExecutable(executable) 212 213 f := &function{ 214 parent: &compiledFunction{parent: cm.compiledCode}, 215 codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])), 216 moduleInstance: j.moduleInstance, 217 } 218 j.ce.initialFn = f 219 j.ce.fn = f 220 221 nativecall( 222 uintptr(unsafe.Pointer(&executable[0])), 223 j.ce, j.moduleInstance, 224 ) 225 } 226 227 func (j *compilerEnv) requireNewCompiler(t *testing.T, functionType *wasm.FunctionType, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl { 228 requireSupportedOSArch(t) 229 230 if ir == nil { 231 ir = &wazeroir.CompilationResult{ 232 LabelCallers: map[wazeroir.Label]uint32{}, 233 } 234 } 235 236 c := fn() 237 c.Init(functionType, ir, false) 238 239 ret, ok := c.(compilerImpl) 240 require.True(t, ok) 241 return ret 242 } 243 244 // compilerImpl is the interface used for architecture-independent unit tests in this pkg. 245 // This is currently implemented by amd64 and arm64. 246 type compilerImpl interface { 247 compiler 248 compileExitFromNativeCode(nativeCallStatusCode) 249 compileMaybeGrowStack() error 250 compileReturnFunction() error 251 assignStackPointerCeil(uint64) 252 setStackPointerCeil(uint64) 253 compileReleaseRegisterToStack(loc *runtimeValueLocation) 254 setRuntimeValueLocationStack(*runtimeValueLocationStack) 255 compileEnsureOnRegister(loc *runtimeValueLocation) error 256 compileModuleContextInitialization() error 257 } 258 259 const defaultMemoryPageNumInTest = 1 260 261 func newCompilerEnvironment() *compilerEnv { 262 me := &moduleEngine{} 263 return &compilerEnv{ 264 me: me, 265 moduleInstance: &wasm.ModuleInstance{ 266 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)}, 267 Tables: []*wasm.TableInstance{}, 268 Globals: []*wasm.GlobalInstance{}, 269 Engine: me, 270 }, 271 ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}), 272 } 273 } 274 275 // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has 276 // the expected stack pointer value relative to the call frame. 277 func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) { 278 require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64) 279 } 280 281 // TestCompileI32WrapFromI64 is the regression test for https://github.com/bananabytelabs/wazero/issues/1008 282 func TestCompileI32WrapFromI64(t *testing.T) { 283 c := newCompiler() 284 c.Init(&wasm.FunctionType{}, nil, false) 285 286 // Push the original i64 value. 287 loc := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack() 288 loc.valueType = runtimeValueTypeI64 289 // Wrap it as the i32, and this should result in having runtimeValueTypeI32 on top of the stack. 290 err := c.compileI32WrapFromI64() 291 require.NoError(t, err) 292 require.Equal(t, runtimeValueTypeI32, loc.valueType) 293 } 294 295 func operationPtr(operation wazeroir.UnionOperation) *wazeroir.UnionOperation { 296 return &operation 297 } 298 299 func requireExecutable(original []byte) (executable []byte) { 300 executable, err := platform.MmapCodeSegment(len(original)) 301 if err != nil { 302 panic(err) 303 } 304 copy(executable, original) 305 makeExecutable(executable) 306 return executable 307 } 308 309 func makeExecutable(executable []byte) { 310 if runtime.GOARCH == "arm64" { 311 if err := platform.MprotectRX(executable); err != nil { 312 panic(err) 313 } 314 } 315 }