github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/compiler/compiler_test.go (about) 1 package compiler 2 3 import ( 4 "fmt" 5 "math" 6 "os" 7 "runtime" 8 "testing" 9 "unsafe" 10 11 "github.com/wasilibs/wazerox/internal/platform" 12 "github.com/wasilibs/wazerox/internal/testing/require" 13 "github.com/wasilibs/wazerox/internal/wasm" 14 "github.com/wasilibs/wazerox/internal/wazeroir" 15 ) 16 17 func TestMain(m *testing.M) { 18 if !platform.CompilerSupported() { 19 os.Exit(0) 20 } 21 os.Exit(m.Run()) 22 } 23 24 // Ensures that the offset consts do not drift when we manipulate the target 25 // structs. 26 // 27 // Note: This is a package initializer as many tests could fail if these 28 // constants are misaligned, hiding the root cause. 29 func init() { 30 var me moduleEngine 31 requireEqual := func(expected, actual int, name string) { 32 if expected != actual { 33 panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual)) 34 } 35 } 36 requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset") 37 38 var ce callEngine 39 // Offsets for callEngine.moduleContext. 40 requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset") 41 requireEqual(int(unsafe.Offsetof(ce.moduleInstance)), callEngineModuleContextModuleInstanceOffset, "callEngineModuleContextModuleInstanceOffset") 42 requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset") 43 requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset") 44 requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset") 45 requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset") 46 requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset") 47 requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset") 48 requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset") 49 requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset") 50 51 // Offsets for callEngine.stackContext 52 requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset") 53 requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset") 54 requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset") 55 requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset") 56 57 // Offsets for callEngine.exitContext. 58 requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset") 59 requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset") 60 requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset") 61 requireEqual(int(unsafe.Offsetof(ce.callerModuleInstance)), callEngineExitContextCallerModuleInstanceOffset, "callEngineExitContextCallerModuleInstanceOffset") 62 63 // Size and offsets for callFrame. 64 var frame callFrame 65 requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize") 66 67 // Offsets for code. 68 var f function 69 requireEqual(int(unsafe.Offsetof(f.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset") 70 requireEqual(int(unsafe.Offsetof(f.moduleInstance)), functionModuleInstanceOffset, "functionModuleInstanceOffset") 71 requireEqual(int(unsafe.Offsetof(f.typeID)), functionTypeIDOffset, "functionTypeIDOffset") 72 requireEqual(int(unsafe.Sizeof(f)), functionSize, "functionModuleInstanceOffset") 73 74 // Offsets for wasm.ModuleInstance. 75 var moduleInstance wasm.ModuleInstance 76 requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset") 77 requireEqual(int(unsafe.Offsetof(moduleInstance.MemoryInstance)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset") 78 requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset") 79 requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset") 80 requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset") 81 requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset") 82 requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset") 83 84 // Offsets for wasm.Table. 85 var tableInstance wasm.TableInstance 86 requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset") 87 // We add "+8" to get the length of Tables[0].Table 88 // since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 89 requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset") 90 91 // Offsets for wasm.Memory 92 var memoryInstance wasm.MemoryInstance 93 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset") 94 // "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 95 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset") 96 97 // Offsets for wasm.GlobalInstance 98 var globalInstance wasm.GlobalInstance 99 requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset") 100 101 var dataInstance wasm.DataInstance 102 requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize") 103 104 var elementInstance wasm.ElementInstance 105 requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize") 106 107 var pointer uintptr 108 requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2") 109 } 110 111 type compilerEnv struct { 112 me *moduleEngine 113 ce *callEngine 114 moduleInstance *wasm.ModuleInstance 115 } 116 117 func (j *compilerEnv) stackTopAsUint32() uint32 { 118 return uint32(j.stack()[j.ce.stackContext.stackPointer-1]) 119 } 120 121 func (j *compilerEnv) stackTopAsInt32() int32 { 122 return int32(j.stack()[j.ce.stackContext.stackPointer-1]) 123 } 124 125 func (j *compilerEnv) stackTopAsUint64() uint64 { 126 return j.stack()[j.ce.stackContext.stackPointer-1] 127 } 128 129 func (j *compilerEnv) stackTopAsInt64() int64 { 130 return int64(j.stack()[j.ce.stackContext.stackPointer-1]) 131 } 132 133 func (j *compilerEnv) stackTopAsFloat32() float32 { 134 return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1])) 135 } 136 137 func (j *compilerEnv) stackTopAsFloat64() float64 { 138 return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1]) 139 } 140 141 func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) { 142 st := j.stack() 143 return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1] 144 } 145 146 func (j *compilerEnv) memory() []byte { 147 return j.moduleInstance.MemoryInstance.Buffer 148 } 149 150 func (j *compilerEnv) stack() []uint64 { 151 return j.ce.stack 152 } 153 154 func (j *compilerEnv) compilerStatus() nativeCallStatusCode { 155 return j.ce.exitContext.statusCode 156 } 157 158 func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index { 159 return j.ce.exitContext.builtinFunctionCallIndex 160 } 161 162 // stackPointer returns the stack pointer minus the call frame. 163 func (j *compilerEnv) stackPointer() uint64 { 164 return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64 165 } 166 167 func (j *compilerEnv) stackBasePointer() uint64 { 168 return j.ce.stackContext.stackBasePointerInBytes >> 3 169 } 170 171 func (j *compilerEnv) setStackPointer(sp uint64) { 172 j.ce.stackContext.stackPointer = sp 173 } 174 175 func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) { 176 j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...) 177 } 178 179 func (j *compilerEnv) globals() []*wasm.GlobalInstance { 180 return j.moduleInstance.Globals 181 } 182 183 func (j *compilerEnv) addTable(table *wasm.TableInstance) { 184 j.moduleInstance.Tables = append(j.moduleInstance.Tables, table) 185 } 186 187 func (j *compilerEnv) setStackBasePointer(sp uint64) { 188 j.ce.stackContext.stackBasePointerInBytes = sp << 3 189 } 190 191 func (j *compilerEnv) module() *wasm.ModuleInstance { 192 return j.moduleInstance 193 } 194 195 func (j *compilerEnv) moduleEngine() *moduleEngine { 196 return j.me 197 } 198 199 func (j *compilerEnv) callEngine() *callEngine { 200 return j.ce 201 } 202 203 func (j *compilerEnv) exec(machineCode []byte) { 204 cm := &compiledModule{compiledCode: &compiledCode{}} 205 if err := cm.executable.Map(len(machineCode)); err != nil { 206 panic(err) 207 } 208 executable := cm.executable.Bytes() 209 copy(executable, machineCode) 210 makeExecutable(executable) 211 212 f := &function{ 213 parent: &compiledFunction{parent: cm.compiledCode}, 214 codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])), 215 moduleInstance: j.moduleInstance, 216 } 217 j.ce.initialFn = f 218 j.ce.fn = f 219 220 nativecall( 221 uintptr(unsafe.Pointer(&executable[0])), 222 j.ce, j.moduleInstance, 223 ) 224 } 225 226 func (j *compilerEnv) requireNewCompiler(t *testing.T, functionType *wasm.FunctionType, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl { 227 requireSupportedOSArch(t) 228 229 if ir == nil { 230 ir = &wazeroir.CompilationResult{ 231 LabelCallers: map[wazeroir.Label]uint32{}, 232 } 233 if j.moduleInstance.MemoryInstance != nil { 234 ir.HasMemory = true 235 } 236 } 237 238 c := fn() 239 c.Init(functionType, ir, false) 240 241 ret, ok := c.(compilerImpl) 242 require.True(t, ok) 243 return ret 244 } 245 246 // compilerImpl is the interface used for architecture-independent unit tests in this pkg. 247 // This is currently implemented by amd64 and arm64. 248 type compilerImpl interface { 249 compiler 250 compileExitFromNativeCode(nativeCallStatusCode) 251 compileMaybeGrowStack() error 252 compileReturnFunction() error 253 assignStackPointerCeil(uint64) 254 setStackPointerCeil(uint64) 255 compileReleaseRegisterToStack(loc *runtimeValueLocation) 256 setRuntimeValueLocationStack(*runtimeValueLocationStack) 257 compileEnsureOnRegister(loc *runtimeValueLocation) error 258 compileModuleContextInitialization() error 259 } 260 261 const defaultMemoryPageNumInTest = 1 262 263 func newCompilerEnvironment() *compilerEnv { 264 me := &moduleEngine{} 265 return &compilerEnv{ 266 me: me, 267 moduleInstance: &wasm.ModuleInstance{ 268 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)}, 269 Tables: []*wasm.TableInstance{}, 270 Globals: []*wasm.GlobalInstance{}, 271 Engine: me, 272 }, 273 ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}), 274 } 275 } 276 277 // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has 278 // the expected stack pointer value relative to the call frame. 279 func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) { 280 require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64) 281 } 282 283 // TestCompileI32WrapFromI64 is the regression test for https://github.com/tetratelabs/wazero/issues/1008 284 func TestCompileI32WrapFromI64(t *testing.T) { 285 c := newCompiler() 286 c.Init(&wasm.FunctionType{}, nil, false) 287 288 // Push the original i64 value. 289 loc := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack() 290 loc.valueType = runtimeValueTypeI64 291 // Wrap it as the i32, and this should result in having runtimeValueTypeI32 on top of the stack. 292 err := c.compileI32WrapFromI64() 293 require.NoError(t, err) 294 require.Equal(t, runtimeValueTypeI32, loc.valueType) 295 } 296 297 func operationPtr(operation wazeroir.UnionOperation) *wazeroir.UnionOperation { 298 return &operation 299 } 300 301 func requireExecutable(original []byte) (executable []byte) { 302 executable, err := platform.MmapCodeSegment(len(original)) 303 if err != nil { 304 panic(err) 305 } 306 copy(executable, original) 307 makeExecutable(executable) 308 return executable 309 } 310 311 func makeExecutable(executable []byte) { 312 if runtime.GOARCH == "arm64" { 313 if err := platform.MprotectRX(executable); err != nil { 314 panic(err) 315 } 316 } 317 }