wa-lang.org/wazero@v1.0.2/internal/engine/compiler/compiler_test.go (about) 1 package compiler 2 3 import ( 4 "fmt" 5 "math" 6 "os" 7 "testing" 8 "unsafe" 9 10 "wa-lang.org/wazero/internal/platform" 11 "wa-lang.org/wazero/internal/testing/require" 12 "wa-lang.org/wazero/internal/wasm" 13 "wa-lang.org/wazero/internal/wazeroir" 14 ) 15 16 func TestMain(m *testing.M) { 17 if !platform.CompilerSupported() { 18 os.Exit(0) 19 } 20 os.Exit(m.Run()) 21 } 22 23 // Ensures that the offset consts do not drift when we manipulate the target 24 // structs. 25 // 26 // Note: This is a package initializer as many tests could fail if these 27 // constants are misaligned, hiding the root cause. 28 func init() { 29 var me moduleEngine 30 requireEqual := func(expected, actual int, name string) { 31 if expected != actual { 32 panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual)) 33 } 34 } 35 requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset") 36 37 var ce callEngine 38 // Offsets for callEngine.moduleContext. 39 requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset") 40 requireEqual(int(unsafe.Offsetof(ce.moduleInstanceAddress)), callEngineModuleContextModuleInstanceAddressOffset, "callEngineModuleContextModuleInstanceAddressOffset") 41 requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset") 42 requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset") 43 requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset") 44 requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset") 45 requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset") 46 requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset") 47 requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset") 48 requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset") 49 requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset") 50 51 // Offsets for callEngine.stackContext 52 requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset") 53 requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset") 54 requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset") 55 requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset") 56 57 // Offsets for callEngine.exitContext. 58 requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset") 59 requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset") 60 requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset") 61 62 // Size and offsets for callFrame. 63 var frame callFrame 64 requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize") 65 66 // Offsets for code. 67 var compiledFunc function 68 requireEqual(int(unsafe.Offsetof(compiledFunc.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset") 69 requireEqual(int(unsafe.Offsetof(compiledFunc.source)), functionSourceOffset, "functionSourceOffset") 70 requireEqual(int(unsafe.Offsetof(compiledFunc.moduleInstanceAddress)), functionModuleInstanceAddressOffset, "functionModuleInstanceAddressOffset") 71 72 // Offsets for wasm.ModuleInstance. 73 var moduleInstance wasm.ModuleInstance 74 requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset") 75 requireEqual(int(unsafe.Offsetof(moduleInstance.Memory)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset") 76 requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset") 77 requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset") 78 requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset") 79 requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset") 80 requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset") 81 82 var functionInstance wasm.FunctionInstance 83 requireEqual(int(unsafe.Offsetof(functionInstance.TypeID)), functionInstanceTypeIDOffset, "functionInstanceTypeIDOffset") 84 85 // Offsets for wasm.Table. 86 var tableInstance wasm.TableInstance 87 requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset") 88 // We add "+8" to get the length of Tables[0].Table 89 // since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 90 requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset") 91 92 // Offsets for wasm.Memory 93 var memoryInstance wasm.MemoryInstance 94 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset") 95 // "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory. 96 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset") 97 98 // Offsets for wasm.GlobalInstance 99 var globalInstance wasm.GlobalInstance 100 requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset") 101 102 var dataInstance wasm.DataInstance 103 requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize") 104 105 var elementInstance wasm.ElementInstance 106 requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize") 107 108 var pointer uintptr 109 requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2") 110 } 111 112 type compilerEnv struct { 113 me *moduleEngine 114 ce *callEngine 115 moduleInstance *wasm.ModuleInstance 116 } 117 118 func (j *compilerEnv) stackTopAsUint32() uint32 { 119 return uint32(j.stack()[j.ce.stackContext.stackPointer-1]) 120 } 121 122 func (j *compilerEnv) stackTopAsInt32() int32 { 123 return int32(j.stack()[j.ce.stackContext.stackPointer-1]) 124 } 125 126 func (j *compilerEnv) stackTopAsUint64() uint64 { 127 return j.stack()[j.ce.stackContext.stackPointer-1] 128 } 129 130 func (j *compilerEnv) stackTopAsInt64() int64 { 131 return int64(j.stack()[j.ce.stackContext.stackPointer-1]) 132 } 133 134 func (j *compilerEnv) stackTopAsFloat32() float32 { 135 return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1])) 136 } 137 138 func (j *compilerEnv) stackTopAsFloat64() float64 { 139 return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1]) 140 } 141 142 func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) { 143 st := j.stack() 144 return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1] 145 } 146 147 func (j *compilerEnv) memory() []byte { 148 return j.moduleInstance.Memory.Buffer 149 } 150 151 func (j *compilerEnv) stack() []uint64 { 152 return j.ce.stack 153 } 154 155 func (j *compilerEnv) compilerStatus() nativeCallStatusCode { 156 return j.ce.exitContext.statusCode 157 } 158 159 func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index { 160 return j.ce.exitContext.builtinFunctionCallIndex 161 } 162 163 // stackPointer returns the stack pointer minus the call frame. 164 func (j *compilerEnv) stackPointer() uint64 { 165 return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64 166 } 167 168 func (j *compilerEnv) stackBasePointer() uint64 { 169 return j.ce.stackContext.stackBasePointerInBytes >> 3 170 } 171 172 func (j *compilerEnv) setStackPointer(sp uint64) { 173 j.ce.stackContext.stackPointer = sp 174 } 175 176 func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) { 177 j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...) 178 } 179 180 func (j *compilerEnv) globals() []*wasm.GlobalInstance { 181 return j.moduleInstance.Globals 182 } 183 184 func (j *compilerEnv) addTable(table *wasm.TableInstance) { 185 j.moduleInstance.Tables = append(j.moduleInstance.Tables, table) 186 } 187 188 func (j *compilerEnv) setStackBasePointer(sp uint64) { 189 j.ce.stackContext.stackBasePointerInBytes = sp << 3 190 } 191 192 func (j *compilerEnv) module() *wasm.ModuleInstance { 193 return j.moduleInstance 194 } 195 196 func (j *compilerEnv) moduleEngine() *moduleEngine { 197 return j.me 198 } 199 200 func (j *compilerEnv) callEngine() *callEngine { 201 return j.ce 202 } 203 204 func (j *compilerEnv) newFunction(codeSegment []byte) *function { 205 return &function{ 206 parent: &code{codeSegment: codeSegment}, 207 codeInitialAddress: uintptr(unsafe.Pointer(&codeSegment[0])), 208 moduleInstanceAddress: uintptr(unsafe.Pointer(j.moduleInstance)), 209 source: &wasm.FunctionInstance{ 210 Type: &wasm.FunctionType{}, 211 Module: j.moduleInstance, 212 }, 213 } 214 } 215 216 func (j *compilerEnv) exec(codeSegment []byte) { 217 f := j.newFunction(codeSegment) 218 j.ce.initialFn = f 219 j.ce.fn = f 220 221 nativecall( 222 uintptr(unsafe.Pointer(&codeSegment[0])), 223 uintptr(unsafe.Pointer(j.ce)), 224 uintptr(unsafe.Pointer(j.moduleInstance)), 225 ) 226 } 227 228 // newTestCompiler allows us to test a different architecture than the current one. 229 type newTestCompiler func(ir *wazeroir.CompilationResult, _ bool) (compiler, error) 230 231 func (j *compilerEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, ir *wazeroir.CompilationResult) compilerImpl { 232 requireSupportedOSArch(t) 233 234 if ir == nil { 235 ir = &wazeroir.CompilationResult{ 236 LabelCallers: map[string]uint32{}, 237 Signature: &wasm.FunctionType{}, 238 } 239 } 240 c, err := fn(ir, false) 241 242 require.NoError(t, err) 243 244 ret, ok := c.(compilerImpl) 245 require.True(t, ok) 246 return ret 247 } 248 249 // CompilerImpl is the interface used for architecture-independent unit tests in this pkg. 250 // This is currently implemented by amd64 and arm64. 251 type compilerImpl interface { 252 compiler 253 compileExitFromNativeCode(nativeCallStatusCode) 254 compileMaybeGrowStack() error 255 compileReturnFunction() error 256 getOnStackPointerCeilDeterminedCallBack() func(uint64) 257 setStackPointerCeil(uint64) 258 compileReleaseRegisterToStack(loc *runtimeValueLocation) 259 setRuntimeValueLocationStack(*runtimeValueLocationStack) 260 compileEnsureOnRegister(loc *runtimeValueLocation) error 261 compileModuleContextInitialization() error 262 } 263 264 const defaultMemoryPageNumInTest = 1 265 266 func newCompilerEnvironment() *compilerEnv { 267 me := &moduleEngine{} 268 return &compilerEnv{ 269 me: me, 270 moduleInstance: &wasm.ModuleInstance{ 271 Memory: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)}, 272 Tables: []*wasm.TableInstance{}, 273 Globals: []*wasm.GlobalInstance{}, 274 Engine: me, 275 }, 276 ce: me.newCallEngine(initialStackSize, nil), 277 } 278 } 279 280 // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has 281 // the expected stack pointer value relative to the call frame. 282 func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) { 283 require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64) 284 }