wa-lang.org/wazero@v1.0.2/internal/engine/compiler/compiler_initialization_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  	"unsafe"
     8  
     9  	"wa-lang.org/wazero/internal/testing/require"
    10  	"wa-lang.org/wazero/internal/wasm"
    11  	"wa-lang.org/wazero/internal/wazeroir"
    12  )
    13  
    14  func TestCompiler_compileModuleContextInitialization(t *testing.T) {
    15  	tests := []struct {
    16  		name           string
    17  		moduleInstance *wasm.ModuleInstance
    18  	}{
    19  		{
    20  			name: "no nil",
    21  			moduleInstance: &wasm.ModuleInstance{
    22  				Globals: []*wasm.GlobalInstance{{Val: 100}},
    23  				Memory:  &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    24  				Tables: []*wasm.TableInstance{
    25  					{References: make([]wasm.Reference, 20)},
    26  					{References: make([]wasm.Reference, 10)},
    27  				},
    28  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    29  				DataInstances:    make([][]byte, 10),
    30  				ElementInstances: make([]wasm.ElementInstance, 10),
    31  			},
    32  		},
    33  		{
    34  			name: "element instances nil",
    35  			moduleInstance: &wasm.ModuleInstance{
    36  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    37  				Memory:           &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    38  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    39  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    40  				DataInstances:    make([][]byte, 10),
    41  				ElementInstances: nil,
    42  			},
    43  		},
    44  		{
    45  			name: "data instances nil",
    46  			moduleInstance: &wasm.ModuleInstance{
    47  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    48  				Memory:           &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    49  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    50  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    51  				DataInstances:    nil,
    52  				ElementInstances: make([]wasm.ElementInstance, 10),
    53  			},
    54  		},
    55  		{
    56  			name: "globals nil",
    57  			moduleInstance: &wasm.ModuleInstance{
    58  				Memory:           &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    59  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    60  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    61  				DataInstances:    make([][]byte, 10),
    62  				ElementInstances: make([]wasm.ElementInstance, 10),
    63  			},
    64  		},
    65  		{
    66  			name: "memory nil",
    67  			moduleInstance: &wasm.ModuleInstance{
    68  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    69  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    70  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    71  				DataInstances:    make([][]byte, 10),
    72  				ElementInstances: make([]wasm.ElementInstance, 10),
    73  			},
    74  		},
    75  		{
    76  			name: "table nil",
    77  			moduleInstance: &wasm.ModuleInstance{
    78  				Memory:           &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    79  				Tables:           []*wasm.TableInstance{{References: nil}},
    80  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    81  				DataInstances:    make([][]byte, 10),
    82  				ElementInstances: make([]wasm.ElementInstance, 10),
    83  			},
    84  		},
    85  		{
    86  			name: "table empty",
    87  			moduleInstance: &wasm.ModuleInstance{
    88  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    89  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    90  				DataInstances:    make([][]byte, 10),
    91  				ElementInstances: make([]wasm.ElementInstance, 10),
    92  			},
    93  		},
    94  		{
    95  			name: "memory zero length",
    96  			moduleInstance: &wasm.ModuleInstance{
    97  				Memory: &wasm.MemoryInstance{Buffer: make([]byte, 0)},
    98  			},
    99  		},
   100  		{
   101  			name:           "all nil except mod engine",
   102  			moduleInstance: &wasm.ModuleInstance{},
   103  		},
   104  	}
   105  
   106  	for _, tt := range tests {
   107  		tc := tt
   108  		t.Run(tc.name, func(t *testing.T) {
   109  			env := newCompilerEnvironment()
   110  			env.moduleInstance = tc.moduleInstance
   111  			ce := env.callEngine()
   112  
   113  			ir := &wazeroir.CompilationResult{
   114  				HasMemory:           tc.moduleInstance.Memory != nil,
   115  				HasTable:            len(tc.moduleInstance.Tables) > 0,
   116  				HasDataInstances:    len(tc.moduleInstance.DataInstances) > 0,
   117  				HasElementInstances: len(tc.moduleInstance.ElementInstances) > 0,
   118  			}
   119  			for _, g := range tc.moduleInstance.Globals {
   120  				ir.Globals = append(ir.Globals, g.Type)
   121  			}
   122  			compiler := env.requireNewCompiler(t, newCompiler, ir)
   123  			me := &moduleEngine{functions: make([]*function, 10)}
   124  			tc.moduleInstance.Engine = me
   125  
   126  			err := compiler.compileModuleContextInitialization()
   127  			require.NoError(t, err)
   128  			require.Zero(t, len(compiler.runtimeValueLocationStack().usedRegisters), "expected no usedRegisters")
   129  
   130  			compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   131  
   132  			// Generate the code under test.
   133  			code, _, err := compiler.compile()
   134  			require.NoError(t, err)
   135  
   136  			env.exec(code)
   137  
   138  			// Check the exit status.
   139  			require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   140  
   141  			// Check if the fields of callEngine.moduleContext are updated.
   142  			bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Globals))
   143  			require.Equal(t, bufSliceHeader.Data, ce.moduleContext.globalElement0Address)
   144  
   145  			if tc.moduleInstance.Memory != nil {
   146  				bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Memory.Buffer))
   147  				require.Equal(t, uint64(bufSliceHeader.Len), ce.moduleContext.memorySliceLen)
   148  				require.Equal(t, bufSliceHeader.Data, ce.moduleContext.memoryElement0Address)
   149  				require.Equal(t, tc.moduleInstance.Memory, ce.moduleContext.memoryInstance)
   150  			}
   151  
   152  			if len(tc.moduleInstance.Tables) > 0 {
   153  				tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Tables))
   154  				require.Equal(t, tableHeader.Data, ce.moduleContext.tablesElement0Address)
   155  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address)
   156  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.Tables[0])), ce.moduleContext.tablesElement0Address)
   157  			}
   158  
   159  			if len(tc.moduleInstance.DataInstances) > 0 {
   160  				dataInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.DataInstances))
   161  				require.Equal(t, dataInstancesHeader.Data, ce.moduleContext.dataInstancesElement0Address)
   162  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.DataInstances[0])), ce.moduleContext.dataInstancesElement0Address)
   163  			}
   164  
   165  			if len(tc.moduleInstance.ElementInstances) > 0 {
   166  				elementInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.ElementInstances))
   167  				require.Equal(t, elementInstancesHeader.Data, ce.moduleContext.elementInstancesElement0Address)
   168  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.ElementInstances[0])), ce.moduleContext.elementInstancesElement0Address)
   169  			}
   170  
   171  			require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.functionsElement0Address)
   172  		})
   173  	}
   174  }
   175  
   176  func TestCompiler_compileMaybeGrowStack(t *testing.T) {
   177  	t.Run("not grow", func(t *testing.T) {
   178  		const stackPointerCeil = 5
   179  		for _, baseOffset := range []uint64{5, 10, 20} {
   180  			t.Run(fmt.Sprintf("%d", baseOffset), func(t *testing.T) {
   181  				env := newCompilerEnvironment()
   182  				compiler := env.requireNewCompiler(t, newCompiler, nil)
   183  
   184  				err := compiler.compilePreamble()
   185  				require.NoError(t, err)
   186  
   187  				require.NotNil(t, compiler.getOnStackPointerCeilDeterminedCallBack())
   188  
   189  				stackLen := uint64(len(env.stack()))
   190  				stackBasePointer := stackLen - baseOffset // Ceil <= stackLen - stackBasePointer = no need to grow!
   191  				compiler.getOnStackPointerCeilDeterminedCallBack()(stackPointerCeil)
   192  				env.setStackBasePointer(stackBasePointer)
   193  
   194  				compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   195  
   196  				// Generate and run the code under test.
   197  				code, _, err := compiler.compile()
   198  				require.NoError(t, err)
   199  				env.exec(code)
   200  
   201  				// The status code must be "Returned", not "BuiltinFunctionCall".
   202  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   203  			})
   204  		}
   205  	})
   206  
   207  	defaultStackLen := uint64(initialStackSize)
   208  	t.Run("grow", func(t *testing.T) {
   209  		tests := []struct {
   210  			name             string
   211  			stackPointerCeil uint64
   212  			stackBasePointer uint64
   213  		}{
   214  			{
   215  				name:             "ceil=6/sbp=len-5",
   216  				stackPointerCeil: 6,
   217  				stackBasePointer: defaultStackLen - 5,
   218  			},
   219  			{
   220  				name:             "ceil=10000/sbp=0",
   221  				stackPointerCeil: 10000,
   222  				stackBasePointer: 0,
   223  			},
   224  		}
   225  
   226  		for _, tc := range tests {
   227  			tc := tc
   228  			t.Run(tc.name, func(t *testing.T) {
   229  				env := newCompilerEnvironment()
   230  				compiler := env.requireNewCompiler(t, newCompiler, nil)
   231  
   232  				err := compiler.compilePreamble()
   233  				require.NoError(t, err)
   234  
   235  				// On the return from grow value stack, we simply return.
   236  				err = compiler.compileReturnFunction()
   237  				require.NoError(t, err)
   238  
   239  				// Generate code under test with the given stackPointerCeil.
   240  				compiler.setStackPointerCeil(tc.stackPointerCeil)
   241  				code, _, err := compiler.compile()
   242  				require.NoError(t, err)
   243  
   244  				// And run the code with the specified stackBasePointer.
   245  				env.setStackBasePointer(tc.stackBasePointer)
   246  				env.exec(code)
   247  
   248  				// Check if the call exits with builtin function call status.
   249  				require.Equal(t, nativeCallStatusCodeCallBuiltInFunction, env.compilerStatus())
   250  
   251  				// Reenter from the return address.
   252  				returnAddress := env.ce.returnAddress
   253  				require.True(t, returnAddress != 0, "returnAddress was zero %d", returnAddress)
   254  				nativecall(
   255  					returnAddress, uintptr(unsafe.Pointer(env.callEngine())),
   256  					uintptr(unsafe.Pointer(env.module())),
   257  				)
   258  
   259  				// Check the result. This should be "Returned".
   260  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   261  			})
   262  		}
   263  	})
   264  }