github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/compiler/compiler_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"os"
     7  	"runtime"
     8  	"testing"
     9  	"unsafe"
    10  
    11  	"github.com/wasilibs/wazerox/internal/platform"
    12  	"github.com/wasilibs/wazerox/internal/testing/require"
    13  	"github.com/wasilibs/wazerox/internal/wasm"
    14  	"github.com/wasilibs/wazerox/internal/wazeroir"
    15  )
    16  
    17  func TestMain(m *testing.M) {
    18  	if !platform.CompilerSupported() {
    19  		os.Exit(0)
    20  	}
    21  	os.Exit(m.Run())
    22  }
    23  
    24  // Ensures that the offset consts do not drift when we manipulate the target
    25  // structs.
    26  //
    27  // Note: This is a package initializer as many tests could fail if these
    28  // constants are misaligned, hiding the root cause.
    29  func init() {
    30  	var me moduleEngine
    31  	requireEqual := func(expected, actual int, name string) {
    32  		if expected != actual {
    33  			panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual))
    34  		}
    35  	}
    36  	requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset")
    37  
    38  	var ce callEngine
    39  	// Offsets for callEngine.moduleContext.
    40  	requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset")
    41  	requireEqual(int(unsafe.Offsetof(ce.moduleInstance)), callEngineModuleContextModuleInstanceOffset, "callEngineModuleContextModuleInstanceOffset")
    42  	requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset")
    43  	requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset")
    44  	requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset")
    45  	requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset")
    46  	requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset")
    47  	requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset")
    48  	requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset")
    49  	requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset")
    50  
    51  	// Offsets for callEngine.stackContext
    52  	requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset")
    53  	requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset")
    54  	requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset")
    55  	requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset")
    56  
    57  	// Offsets for callEngine.exitContext.
    58  	requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset")
    59  	requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset")
    60  	requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset")
    61  	requireEqual(int(unsafe.Offsetof(ce.callerModuleInstance)), callEngineExitContextCallerModuleInstanceOffset, "callEngineExitContextCallerModuleInstanceOffset")
    62  
    63  	// Size and offsets for callFrame.
    64  	var frame callFrame
    65  	requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize")
    66  
    67  	// Offsets for code.
    68  	var f function
    69  	requireEqual(int(unsafe.Offsetof(f.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset")
    70  	requireEqual(int(unsafe.Offsetof(f.moduleInstance)), functionModuleInstanceOffset, "functionModuleInstanceOffset")
    71  	requireEqual(int(unsafe.Offsetof(f.typeID)), functionTypeIDOffset, "functionTypeIDOffset")
    72  	requireEqual(int(unsafe.Sizeof(f)), functionSize, "functionModuleInstanceOffset")
    73  
    74  	// Offsets for wasm.ModuleInstance.
    75  	var moduleInstance wasm.ModuleInstance
    76  	requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset")
    77  	requireEqual(int(unsafe.Offsetof(moduleInstance.MemoryInstance)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset")
    78  	requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset")
    79  	requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset")
    80  	requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset")
    81  	requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset")
    82  	requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset")
    83  
    84  	// Offsets for wasm.Table.
    85  	var tableInstance wasm.TableInstance
    86  	requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset")
    87  	// We add "+8" to get the length of Tables[0].Table
    88  	// since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    89  	requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset")
    90  
    91  	// Offsets for wasm.Memory
    92  	var memoryInstance wasm.MemoryInstance
    93  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset")
    94  	// "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    95  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset")
    96  
    97  	// Offsets for wasm.GlobalInstance
    98  	var globalInstance wasm.GlobalInstance
    99  	requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset")
   100  
   101  	var dataInstance wasm.DataInstance
   102  	requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize")
   103  
   104  	var elementInstance wasm.ElementInstance
   105  	requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize")
   106  
   107  	var pointer uintptr
   108  	requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2")
   109  }
   110  
   111  type compilerEnv struct {
   112  	me             *moduleEngine
   113  	ce             *callEngine
   114  	moduleInstance *wasm.ModuleInstance
   115  }
   116  
   117  func (j *compilerEnv) stackTopAsUint32() uint32 {
   118  	return uint32(j.stack()[j.ce.stackContext.stackPointer-1])
   119  }
   120  
   121  func (j *compilerEnv) stackTopAsInt32() int32 {
   122  	return int32(j.stack()[j.ce.stackContext.stackPointer-1])
   123  }
   124  
   125  func (j *compilerEnv) stackTopAsUint64() uint64 {
   126  	return j.stack()[j.ce.stackContext.stackPointer-1]
   127  }
   128  
   129  func (j *compilerEnv) stackTopAsInt64() int64 {
   130  	return int64(j.stack()[j.ce.stackContext.stackPointer-1])
   131  }
   132  
   133  func (j *compilerEnv) stackTopAsFloat32() float32 {
   134  	return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1]))
   135  }
   136  
   137  func (j *compilerEnv) stackTopAsFloat64() float64 {
   138  	return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1])
   139  }
   140  
   141  func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) {
   142  	st := j.stack()
   143  	return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1]
   144  }
   145  
   146  func (j *compilerEnv) memory() []byte {
   147  	return j.moduleInstance.MemoryInstance.Buffer
   148  }
   149  
   150  func (j *compilerEnv) stack() []uint64 {
   151  	return j.ce.stack
   152  }
   153  
   154  func (j *compilerEnv) compilerStatus() nativeCallStatusCode {
   155  	return j.ce.exitContext.statusCode
   156  }
   157  
   158  func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index {
   159  	return j.ce.exitContext.builtinFunctionCallIndex
   160  }
   161  
   162  // stackPointer returns the stack pointer minus the call frame.
   163  func (j *compilerEnv) stackPointer() uint64 {
   164  	return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64
   165  }
   166  
   167  func (j *compilerEnv) stackBasePointer() uint64 {
   168  	return j.ce.stackContext.stackBasePointerInBytes >> 3
   169  }
   170  
   171  func (j *compilerEnv) setStackPointer(sp uint64) {
   172  	j.ce.stackContext.stackPointer = sp
   173  }
   174  
   175  func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) {
   176  	j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...)
   177  }
   178  
   179  func (j *compilerEnv) globals() []*wasm.GlobalInstance {
   180  	return j.moduleInstance.Globals
   181  }
   182  
   183  func (j *compilerEnv) addTable(table *wasm.TableInstance) {
   184  	j.moduleInstance.Tables = append(j.moduleInstance.Tables, table)
   185  }
   186  
   187  func (j *compilerEnv) setStackBasePointer(sp uint64) {
   188  	j.ce.stackContext.stackBasePointerInBytes = sp << 3
   189  }
   190  
   191  func (j *compilerEnv) module() *wasm.ModuleInstance {
   192  	return j.moduleInstance
   193  }
   194  
   195  func (j *compilerEnv) moduleEngine() *moduleEngine {
   196  	return j.me
   197  }
   198  
   199  func (j *compilerEnv) callEngine() *callEngine {
   200  	return j.ce
   201  }
   202  
   203  func (j *compilerEnv) exec(machineCode []byte) {
   204  	cm := &compiledModule{compiledCode: &compiledCode{}}
   205  	if err := cm.executable.Map(len(machineCode)); err != nil {
   206  		panic(err)
   207  	}
   208  	executable := cm.executable.Bytes()
   209  	copy(executable, machineCode)
   210  	makeExecutable(executable)
   211  
   212  	f := &function{
   213  		parent:             &compiledFunction{parent: cm.compiledCode},
   214  		codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
   215  		moduleInstance:     j.moduleInstance,
   216  	}
   217  	j.ce.initialFn = f
   218  	j.ce.fn = f
   219  
   220  	nativecall(
   221  		uintptr(unsafe.Pointer(&executable[0])),
   222  		j.ce, j.moduleInstance,
   223  	)
   224  }
   225  
   226  func (j *compilerEnv) requireNewCompiler(t *testing.T, functionType *wasm.FunctionType, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl {
   227  	requireSupportedOSArch(t)
   228  
   229  	if ir == nil {
   230  		ir = &wazeroir.CompilationResult{
   231  			LabelCallers: map[wazeroir.Label]uint32{},
   232  		}
   233  		if j.moduleInstance.MemoryInstance != nil {
   234  			ir.HasMemory = true
   235  		}
   236  	}
   237  
   238  	c := fn()
   239  	c.Init(functionType, ir, false)
   240  
   241  	ret, ok := c.(compilerImpl)
   242  	require.True(t, ok)
   243  	return ret
   244  }
   245  
   246  // compilerImpl is the interface used for architecture-independent unit tests in this pkg.
   247  // This is currently implemented by amd64 and arm64.
   248  type compilerImpl interface {
   249  	compiler
   250  	compileExitFromNativeCode(nativeCallStatusCode)
   251  	compileMaybeGrowStack() error
   252  	compileReturnFunction() error
   253  	assignStackPointerCeil(uint64)
   254  	setStackPointerCeil(uint64)
   255  	compileReleaseRegisterToStack(loc *runtimeValueLocation)
   256  	setRuntimeValueLocationStack(*runtimeValueLocationStack)
   257  	compileEnsureOnRegister(loc *runtimeValueLocation) error
   258  	compileModuleContextInitialization() error
   259  }
   260  
   261  const defaultMemoryPageNumInTest = 1
   262  
   263  func newCompilerEnvironment() *compilerEnv {
   264  	me := &moduleEngine{}
   265  	return &compilerEnv{
   266  		me: me,
   267  		moduleInstance: &wasm.ModuleInstance{
   268  			MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)},
   269  			Tables:         []*wasm.TableInstance{},
   270  			Globals:        []*wasm.GlobalInstance{},
   271  			Engine:         me,
   272  		},
   273  		ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
   274  	}
   275  }
   276  
   277  // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has
   278  // the expected stack pointer value relative to the call frame.
   279  func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) {
   280  	require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64)
   281  }
   282  
   283  // TestCompileI32WrapFromI64 is the regression test for https://github.com/tetratelabs/wazero/issues/1008
   284  func TestCompileI32WrapFromI64(t *testing.T) {
   285  	c := newCompiler()
   286  	c.Init(&wasm.FunctionType{}, nil, false)
   287  
   288  	// Push the original i64 value.
   289  	loc := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   290  	loc.valueType = runtimeValueTypeI64
   291  	// Wrap it as the i32, and this should result in having runtimeValueTypeI32 on top of the stack.
   292  	err := c.compileI32WrapFromI64()
   293  	require.NoError(t, err)
   294  	require.Equal(t, runtimeValueTypeI32, loc.valueType)
   295  }
   296  
   297  func operationPtr(operation wazeroir.UnionOperation) *wazeroir.UnionOperation {
   298  	return &operation
   299  }
   300  
   301  func requireExecutable(original []byte) (executable []byte) {
   302  	executable, err := platform.MmapCodeSegment(len(original))
   303  	if err != nil {
   304  		panic(err)
   305  	}
   306  	copy(executable, original)
   307  	makeExecutable(executable)
   308  	return executable
   309  }
   310  
   311  func makeExecutable(executable []byte) {
   312  	if runtime.GOARCH == "arm64" {
   313  		if err := platform.MprotectRX(executable); err != nil {
   314  			panic(err)
   315  		}
   316  	}
   317  }