github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"os"
     7  	"runtime"
     8  	"testing"
     9  	"unsafe"
    10  
    11  	"github.com/bananabytelabs/wazero/internal/platform"
    12  	"github.com/bananabytelabs/wazero/internal/testing/require"
    13  	"github.com/bananabytelabs/wazero/internal/wasm"
    14  	"github.com/bananabytelabs/wazero/internal/wazeroir"
    15  )
    16  
    17  func TestMain(m *testing.M) {
    18  	if !platform.CompilerSupported() {
    19  		os.Exit(0)
    20  	}
    21  	os.Exit(m.Run())
    22  }
    23  
    24  // Ensures that the offset consts do not drift when we manipulate the target
    25  // structs.
    26  //
    27  // Note: This is a package initializer as many tests could fail if these
    28  // constants are misaligned, hiding the root cause.
    29  func init() {
    30  	var me moduleEngine
    31  	requireEqual := func(expected, actual int, name string) {
    32  		if expected != actual {
    33  			panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual))
    34  		}
    35  	}
    36  	requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset")
    37  
    38  	var ce callEngine
    39  	// Offsets for callEngine.moduleContext.
    40  	requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset")
    41  	requireEqual(int(unsafe.Offsetof(ce.moduleInstance)), callEngineModuleContextModuleInstanceOffset, "callEngineModuleContextModuleInstanceOffset")
    42  	requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset")
    43  	requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset")
    44  	requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset")
    45  	requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset")
    46  	requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset")
    47  	requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset")
    48  	requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset")
    49  	requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset")
    50  	requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset")
    51  
    52  	// Offsets for callEngine.stackContext
    53  	requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset")
    54  	requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset")
    55  	requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset")
    56  	requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset")
    57  
    58  	// Offsets for callEngine.exitContext.
    59  	requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset")
    60  	requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset")
    61  	requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset")
    62  	requireEqual(int(unsafe.Offsetof(ce.callerModuleInstance)), callEngineExitContextCallerModuleInstanceOffset, "callEngineExitContextCallerModuleInstanceOffset")
    63  
    64  	// Size and offsets for callFrame.
    65  	var frame callFrame
    66  	requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize")
    67  
    68  	// Offsets for code.
    69  	var f function
    70  	requireEqual(int(unsafe.Offsetof(f.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset")
    71  	requireEqual(int(unsafe.Offsetof(f.moduleInstance)), functionModuleInstanceOffset, "functionModuleInstanceOffset")
    72  	requireEqual(int(unsafe.Offsetof(f.typeID)), functionTypeIDOffset, "functionTypeIDOffset")
    73  	requireEqual(int(unsafe.Sizeof(f)), functionSize, "functionModuleInstanceOffset")
    74  
    75  	// Offsets for wasm.ModuleInstance.
    76  	var moduleInstance wasm.ModuleInstance
    77  	requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset")
    78  	requireEqual(int(unsafe.Offsetof(moduleInstance.MemoryInstance)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset")
    79  	requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset")
    80  	requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset")
    81  	requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset")
    82  	requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset")
    83  	requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset")
    84  
    85  	// Offsets for wasm.Table.
    86  	var tableInstance wasm.TableInstance
    87  	requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset")
    88  	// We add "+8" to get the length of Tables[0].Table
    89  	// since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    90  	requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset")
    91  
    92  	// Offsets for wasm.Memory
    93  	var memoryInstance wasm.MemoryInstance
    94  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset")
    95  	// "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    96  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset")
    97  
    98  	// Offsets for wasm.GlobalInstance
    99  	var globalInstance wasm.GlobalInstance
   100  	requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset")
   101  
   102  	var dataInstance wasm.DataInstance
   103  	requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize")
   104  
   105  	var elementInstance wasm.ElementInstance
   106  	requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize")
   107  
   108  	var pointer uintptr
   109  	requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2")
   110  }
   111  
   112  type compilerEnv struct {
   113  	me             *moduleEngine
   114  	ce             *callEngine
   115  	moduleInstance *wasm.ModuleInstance
   116  }
   117  
   118  func (j *compilerEnv) stackTopAsUint32() uint32 {
   119  	return uint32(j.stack()[j.ce.stackContext.stackPointer-1])
   120  }
   121  
   122  func (j *compilerEnv) stackTopAsInt32() int32 {
   123  	return int32(j.stack()[j.ce.stackContext.stackPointer-1])
   124  }
   125  
   126  func (j *compilerEnv) stackTopAsUint64() uint64 {
   127  	return j.stack()[j.ce.stackContext.stackPointer-1]
   128  }
   129  
   130  func (j *compilerEnv) stackTopAsInt64() int64 {
   131  	return int64(j.stack()[j.ce.stackContext.stackPointer-1])
   132  }
   133  
   134  func (j *compilerEnv) stackTopAsFloat32() float32 {
   135  	return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1]))
   136  }
   137  
   138  func (j *compilerEnv) stackTopAsFloat64() float64 {
   139  	return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1])
   140  }
   141  
   142  func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) {
   143  	st := j.stack()
   144  	return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1]
   145  }
   146  
   147  func (j *compilerEnv) memory() []byte {
   148  	return j.moduleInstance.MemoryInstance.Buffer
   149  }
   150  
   151  func (j *compilerEnv) stack() []uint64 {
   152  	return j.ce.stack
   153  }
   154  
   155  func (j *compilerEnv) compilerStatus() nativeCallStatusCode {
   156  	return j.ce.exitContext.statusCode
   157  }
   158  
   159  func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index {
   160  	return j.ce.exitContext.builtinFunctionCallIndex
   161  }
   162  
   163  // stackPointer returns the stack pointer minus the call frame.
   164  func (j *compilerEnv) stackPointer() uint64 {
   165  	return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64
   166  }
   167  
   168  func (j *compilerEnv) stackBasePointer() uint64 {
   169  	return j.ce.stackContext.stackBasePointerInBytes >> 3
   170  }
   171  
   172  func (j *compilerEnv) setStackPointer(sp uint64) {
   173  	j.ce.stackContext.stackPointer = sp
   174  }
   175  
   176  func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) {
   177  	j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...)
   178  }
   179  
   180  func (j *compilerEnv) globals() []*wasm.GlobalInstance {
   181  	return j.moduleInstance.Globals
   182  }
   183  
   184  func (j *compilerEnv) addTable(table *wasm.TableInstance) {
   185  	j.moduleInstance.Tables = append(j.moduleInstance.Tables, table)
   186  }
   187  
   188  func (j *compilerEnv) setStackBasePointer(sp uint64) {
   189  	j.ce.stackContext.stackBasePointerInBytes = sp << 3
   190  }
   191  
   192  func (j *compilerEnv) module() *wasm.ModuleInstance {
   193  	return j.moduleInstance
   194  }
   195  
   196  func (j *compilerEnv) moduleEngine() *moduleEngine {
   197  	return j.me
   198  }
   199  
   200  func (j *compilerEnv) callEngine() *callEngine {
   201  	return j.ce
   202  }
   203  
   204  func (j *compilerEnv) exec(machineCode []byte) {
   205  	cm := &compiledModule{compiledCode: &compiledCode{}}
   206  	if err := cm.executable.Map(len(machineCode)); err != nil {
   207  		panic(err)
   208  	}
   209  	executable := cm.executable.Bytes()
   210  	copy(executable, machineCode)
   211  	makeExecutable(executable)
   212  
   213  	f := &function{
   214  		parent:             &compiledFunction{parent: cm.compiledCode},
   215  		codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
   216  		moduleInstance:     j.moduleInstance,
   217  	}
   218  	j.ce.initialFn = f
   219  	j.ce.fn = f
   220  
   221  	nativecall(
   222  		uintptr(unsafe.Pointer(&executable[0])),
   223  		j.ce, j.moduleInstance,
   224  	)
   225  }
   226  
   227  func (j *compilerEnv) requireNewCompiler(t *testing.T, functionType *wasm.FunctionType, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl {
   228  	requireSupportedOSArch(t)
   229  
   230  	if ir == nil {
   231  		ir = &wazeroir.CompilationResult{
   232  			LabelCallers: map[wazeroir.Label]uint32{},
   233  		}
   234  	}
   235  
   236  	c := fn()
   237  	c.Init(functionType, ir, false)
   238  
   239  	ret, ok := c.(compilerImpl)
   240  	require.True(t, ok)
   241  	return ret
   242  }
   243  
   244  // compilerImpl is the interface used for architecture-independent unit tests in this pkg.
   245  // This is currently implemented by amd64 and arm64.
   246  type compilerImpl interface {
   247  	compiler
   248  	compileExitFromNativeCode(nativeCallStatusCode)
   249  	compileMaybeGrowStack() error
   250  	compileReturnFunction() error
   251  	assignStackPointerCeil(uint64)
   252  	setStackPointerCeil(uint64)
   253  	compileReleaseRegisterToStack(loc *runtimeValueLocation)
   254  	setRuntimeValueLocationStack(*runtimeValueLocationStack)
   255  	compileEnsureOnRegister(loc *runtimeValueLocation) error
   256  	compileModuleContextInitialization() error
   257  }
   258  
   259  const defaultMemoryPageNumInTest = 1
   260  
   261  func newCompilerEnvironment() *compilerEnv {
   262  	me := &moduleEngine{}
   263  	return &compilerEnv{
   264  		me: me,
   265  		moduleInstance: &wasm.ModuleInstance{
   266  			MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)},
   267  			Tables:         []*wasm.TableInstance{},
   268  			Globals:        []*wasm.GlobalInstance{},
   269  			Engine:         me,
   270  		},
   271  		ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
   272  	}
   273  }
   274  
   275  // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has
   276  // the expected stack pointer value relative to the call frame.
   277  func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) {
   278  	require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64)
   279  }
   280  
   281  // TestCompileI32WrapFromI64 is the regression test for https://github.com/bananabytelabs/wazero/issues/1008
   282  func TestCompileI32WrapFromI64(t *testing.T) {
   283  	c := newCompiler()
   284  	c.Init(&wasm.FunctionType{}, nil, false)
   285  
   286  	// Push the original i64 value.
   287  	loc := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   288  	loc.valueType = runtimeValueTypeI64
   289  	// Wrap it as the i32, and this should result in having runtimeValueTypeI32 on top of the stack.
   290  	err := c.compileI32WrapFromI64()
   291  	require.NoError(t, err)
   292  	require.Equal(t, runtimeValueTypeI32, loc.valueType)
   293  }
   294  
   295  func operationPtr(operation wazeroir.UnionOperation) *wazeroir.UnionOperation {
   296  	return &operation
   297  }
   298  
   299  func requireExecutable(original []byte) (executable []byte) {
   300  	executable, err := platform.MmapCodeSegment(len(original))
   301  	if err != nil {
   302  		panic(err)
   303  	}
   304  	copy(executable, original)
   305  	makeExecutable(executable)
   306  	return executable
   307  }
   308  
   309  func makeExecutable(executable []byte) {
   310  	if runtime.GOARCH == "arm64" {
   311  		if err := platform.MprotectRX(executable); err != nil {
   312  			panic(err)
   313  		}
   314  	}
   315  }