wa-lang.org/wazero@v1.0.2/internal/engine/compiler/compiler_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"os"
     7  	"testing"
     8  	"unsafe"
     9  
    10  	"wa-lang.org/wazero/internal/platform"
    11  	"wa-lang.org/wazero/internal/testing/require"
    12  	"wa-lang.org/wazero/internal/wasm"
    13  	"wa-lang.org/wazero/internal/wazeroir"
    14  )
    15  
    16  func TestMain(m *testing.M) {
    17  	if !platform.CompilerSupported() {
    18  		os.Exit(0)
    19  	}
    20  	os.Exit(m.Run())
    21  }
    22  
    23  // Ensures that the offset consts do not drift when we manipulate the target
    24  // structs.
    25  //
    26  // Note: This is a package initializer as many tests could fail if these
    27  // constants are misaligned, hiding the root cause.
    28  func init() {
    29  	var me moduleEngine
    30  	requireEqual := func(expected, actual int, name string) {
    31  		if expected != actual {
    32  			panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual))
    33  		}
    34  	}
    35  	requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset")
    36  
    37  	var ce callEngine
    38  	// Offsets for callEngine.moduleContext.
    39  	requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset")
    40  	requireEqual(int(unsafe.Offsetof(ce.moduleInstanceAddress)), callEngineModuleContextModuleInstanceAddressOffset, "callEngineModuleContextModuleInstanceAddressOffset")
    41  	requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset")
    42  	requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset")
    43  	requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset")
    44  	requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset")
    45  	requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset")
    46  	requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset")
    47  	requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset")
    48  	requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset")
    49  	requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset")
    50  
    51  	// Offsets for callEngine.stackContext
    52  	requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset")
    53  	requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset")
    54  	requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset")
    55  	requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset")
    56  
    57  	// Offsets for callEngine.exitContext.
    58  	requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset")
    59  	requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset")
    60  	requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset")
    61  
    62  	// Size and offsets for callFrame.
    63  	var frame callFrame
    64  	requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize")
    65  
    66  	// Offsets for code.
    67  	var compiledFunc function
    68  	requireEqual(int(unsafe.Offsetof(compiledFunc.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset")
    69  	requireEqual(int(unsafe.Offsetof(compiledFunc.source)), functionSourceOffset, "functionSourceOffset")
    70  	requireEqual(int(unsafe.Offsetof(compiledFunc.moduleInstanceAddress)), functionModuleInstanceAddressOffset, "functionModuleInstanceAddressOffset")
    71  
    72  	// Offsets for wasm.ModuleInstance.
    73  	var moduleInstance wasm.ModuleInstance
    74  	requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset")
    75  	requireEqual(int(unsafe.Offsetof(moduleInstance.Memory)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset")
    76  	requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset")
    77  	requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset")
    78  	requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset")
    79  	requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset")
    80  	requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset")
    81  
    82  	var functionInstance wasm.FunctionInstance
    83  	requireEqual(int(unsafe.Offsetof(functionInstance.TypeID)), functionInstanceTypeIDOffset, "functionInstanceTypeIDOffset")
    84  
    85  	// Offsets for wasm.Table.
    86  	var tableInstance wasm.TableInstance
    87  	requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset")
    88  	// We add "+8" to get the length of Tables[0].Table
    89  	// since the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    90  	requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset")
    91  
    92  	// Offsets for wasm.Memory
    93  	var memoryInstance wasm.MemoryInstance
    94  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset")
    95  	// "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} on memory.
    96  	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset")
    97  
    98  	// Offsets for wasm.GlobalInstance
    99  	var globalInstance wasm.GlobalInstance
   100  	requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset")
   101  
   102  	var dataInstance wasm.DataInstance
   103  	requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize")
   104  
   105  	var elementInstance wasm.ElementInstance
   106  	requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize")
   107  
   108  	var pointer uintptr
   109  	requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2")
   110  }
   111  
   112  type compilerEnv struct {
   113  	me             *moduleEngine
   114  	ce             *callEngine
   115  	moduleInstance *wasm.ModuleInstance
   116  }
   117  
   118  func (j *compilerEnv) stackTopAsUint32() uint32 {
   119  	return uint32(j.stack()[j.ce.stackContext.stackPointer-1])
   120  }
   121  
   122  func (j *compilerEnv) stackTopAsInt32() int32 {
   123  	return int32(j.stack()[j.ce.stackContext.stackPointer-1])
   124  }
   125  
   126  func (j *compilerEnv) stackTopAsUint64() uint64 {
   127  	return j.stack()[j.ce.stackContext.stackPointer-1]
   128  }
   129  
   130  func (j *compilerEnv) stackTopAsInt64() int64 {
   131  	return int64(j.stack()[j.ce.stackContext.stackPointer-1])
   132  }
   133  
   134  func (j *compilerEnv) stackTopAsFloat32() float32 {
   135  	return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1]))
   136  }
   137  
   138  func (j *compilerEnv) stackTopAsFloat64() float64 {
   139  	return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1])
   140  }
   141  
   142  func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) {
   143  	st := j.stack()
   144  	return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1]
   145  }
   146  
   147  func (j *compilerEnv) memory() []byte {
   148  	return j.moduleInstance.Memory.Buffer
   149  }
   150  
   151  func (j *compilerEnv) stack() []uint64 {
   152  	return j.ce.stack
   153  }
   154  
   155  func (j *compilerEnv) compilerStatus() nativeCallStatusCode {
   156  	return j.ce.exitContext.statusCode
   157  }
   158  
   159  func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index {
   160  	return j.ce.exitContext.builtinFunctionCallIndex
   161  }
   162  
   163  // stackPointer returns the stack pointer minus the call frame.
   164  func (j *compilerEnv) stackPointer() uint64 {
   165  	return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64
   166  }
   167  
   168  func (j *compilerEnv) stackBasePointer() uint64 {
   169  	return j.ce.stackContext.stackBasePointerInBytes >> 3
   170  }
   171  
   172  func (j *compilerEnv) setStackPointer(sp uint64) {
   173  	j.ce.stackContext.stackPointer = sp
   174  }
   175  
   176  func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) {
   177  	j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...)
   178  }
   179  
   180  func (j *compilerEnv) globals() []*wasm.GlobalInstance {
   181  	return j.moduleInstance.Globals
   182  }
   183  
   184  func (j *compilerEnv) addTable(table *wasm.TableInstance) {
   185  	j.moduleInstance.Tables = append(j.moduleInstance.Tables, table)
   186  }
   187  
   188  func (j *compilerEnv) setStackBasePointer(sp uint64) {
   189  	j.ce.stackContext.stackBasePointerInBytes = sp << 3
   190  }
   191  
   192  func (j *compilerEnv) module() *wasm.ModuleInstance {
   193  	return j.moduleInstance
   194  }
   195  
   196  func (j *compilerEnv) moduleEngine() *moduleEngine {
   197  	return j.me
   198  }
   199  
   200  func (j *compilerEnv) callEngine() *callEngine {
   201  	return j.ce
   202  }
   203  
   204  func (j *compilerEnv) newFunction(codeSegment []byte) *function {
   205  	return &function{
   206  		parent:                &code{codeSegment: codeSegment},
   207  		codeInitialAddress:    uintptr(unsafe.Pointer(&codeSegment[0])),
   208  		moduleInstanceAddress: uintptr(unsafe.Pointer(j.moduleInstance)),
   209  		source: &wasm.FunctionInstance{
   210  			Type:   &wasm.FunctionType{},
   211  			Module: j.moduleInstance,
   212  		},
   213  	}
   214  }
   215  
   216  func (j *compilerEnv) exec(codeSegment []byte) {
   217  	f := j.newFunction(codeSegment)
   218  	j.ce.initialFn = f
   219  	j.ce.fn = f
   220  
   221  	nativecall(
   222  		uintptr(unsafe.Pointer(&codeSegment[0])),
   223  		uintptr(unsafe.Pointer(j.ce)),
   224  		uintptr(unsafe.Pointer(j.moduleInstance)),
   225  	)
   226  }
   227  
   228  // newTestCompiler allows us to test a different architecture than the current one.
   229  type newTestCompiler func(ir *wazeroir.CompilationResult, _ bool) (compiler, error)
   230  
   231  func (j *compilerEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, ir *wazeroir.CompilationResult) compilerImpl {
   232  	requireSupportedOSArch(t)
   233  
   234  	if ir == nil {
   235  		ir = &wazeroir.CompilationResult{
   236  			LabelCallers: map[string]uint32{},
   237  			Signature:    &wasm.FunctionType{},
   238  		}
   239  	}
   240  	c, err := fn(ir, false)
   241  
   242  	require.NoError(t, err)
   243  
   244  	ret, ok := c.(compilerImpl)
   245  	require.True(t, ok)
   246  	return ret
   247  }
   248  
   249  // CompilerImpl is the interface used for architecture-independent unit tests in this pkg.
   250  // This is currently implemented by amd64 and arm64.
   251  type compilerImpl interface {
   252  	compiler
   253  	compileExitFromNativeCode(nativeCallStatusCode)
   254  	compileMaybeGrowStack() error
   255  	compileReturnFunction() error
   256  	getOnStackPointerCeilDeterminedCallBack() func(uint64)
   257  	setStackPointerCeil(uint64)
   258  	compileReleaseRegisterToStack(loc *runtimeValueLocation)
   259  	setRuntimeValueLocationStack(*runtimeValueLocationStack)
   260  	compileEnsureOnRegister(loc *runtimeValueLocation) error
   261  	compileModuleContextInitialization() error
   262  }
   263  
   264  const defaultMemoryPageNumInTest = 1
   265  
   266  func newCompilerEnvironment() *compilerEnv {
   267  	me := &moduleEngine{}
   268  	return &compilerEnv{
   269  		me: me,
   270  		moduleInstance: &wasm.ModuleInstance{
   271  			Memory:  &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)},
   272  			Tables:  []*wasm.TableInstance{},
   273  			Globals: []*wasm.GlobalInstance{},
   274  			Engine:  me,
   275  		},
   276  		ce: me.newCallEngine(initialStackSize, nil),
   277  	}
   278  }
   279  
   280  // requireRuntimeLocationStackPointerEqual ensures that the compiler's runtimeValueLocationStack has
   281  // the expected stack pointer value relative to the call frame.
   282  func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) {
   283  	require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64)
   284  }