github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/compiler/compiler_initialization_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  	"unsafe"
     8  
     9  	"github.com/wasilibs/wazerox/internal/asm"
    10  	"github.com/wasilibs/wazerox/internal/testing/require"
    11  	"github.com/wasilibs/wazerox/internal/wasm"
    12  	"github.com/wasilibs/wazerox/internal/wazeroir"
    13  )
    14  
    15  func TestCompiler_compileModuleContextInitialization(t *testing.T) {
    16  	tests := []struct {
    17  		name           string
    18  		moduleInstance *wasm.ModuleInstance
    19  	}{
    20  		{
    21  			name: "no nil",
    22  			moduleInstance: &wasm.ModuleInstance{
    23  				Globals:        []*wasm.GlobalInstance{{Val: 100}},
    24  				MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    25  				Tables: []*wasm.TableInstance{
    26  					{References: make([]wasm.Reference, 20)},
    27  					{References: make([]wasm.Reference, 10)},
    28  				},
    29  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    30  				DataInstances:    make([][]byte, 10),
    31  				ElementInstances: make([]wasm.ElementInstance, 10),
    32  			},
    33  		},
    34  		{
    35  			name: "element instances nil",
    36  			moduleInstance: &wasm.ModuleInstance{
    37  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    38  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    39  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    40  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    41  				DataInstances:    make([][]byte, 10),
    42  				ElementInstances: nil,
    43  			},
    44  		},
    45  		{
    46  			name: "data instances nil",
    47  			moduleInstance: &wasm.ModuleInstance{
    48  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    49  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    50  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    51  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    52  				DataInstances:    nil,
    53  				ElementInstances: make([]wasm.ElementInstance, 10),
    54  			},
    55  		},
    56  		{
    57  			name: "globals nil",
    58  			moduleInstance: &wasm.ModuleInstance{
    59  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    60  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    61  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    62  				DataInstances:    make([][]byte, 10),
    63  				ElementInstances: make([]wasm.ElementInstance, 10),
    64  			},
    65  		},
    66  		{
    67  			name: "memory nil",
    68  			moduleInstance: &wasm.ModuleInstance{
    69  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    70  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    71  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    72  				DataInstances:    make([][]byte, 10),
    73  				ElementInstances: make([]wasm.ElementInstance, 10),
    74  			},
    75  		},
    76  		{
    77  			name: "table nil",
    78  			moduleInstance: &wasm.ModuleInstance{
    79  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    80  				Tables:           []*wasm.TableInstance{{References: nil}},
    81  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    82  				DataInstances:    make([][]byte, 10),
    83  				ElementInstances: make([]wasm.ElementInstance, 10),
    84  			},
    85  		},
    86  		{
    87  			name: "table empty",
    88  			moduleInstance: &wasm.ModuleInstance{
    89  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    90  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    91  				DataInstances:    make([][]byte, 10),
    92  				ElementInstances: make([]wasm.ElementInstance, 10),
    93  			},
    94  		},
    95  		{
    96  			name: "memory zero length",
    97  			moduleInstance: &wasm.ModuleInstance{
    98  				MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 0)},
    99  			},
   100  		},
   101  		{
   102  			name:           "all nil except mod engine",
   103  			moduleInstance: &wasm.ModuleInstance{},
   104  		},
   105  	}
   106  
   107  	for _, tt := range tests {
   108  		tc := tt
   109  		t.Run(tc.name, func(t *testing.T) {
   110  			env := newCompilerEnvironment()
   111  			env.moduleInstance = tc.moduleInstance
   112  			ce := env.callEngine()
   113  
   114  			ir := &wazeroir.CompilationResult{
   115  				HasMemory:           tc.moduleInstance.MemoryInstance != nil,
   116  				HasTable:            len(tc.moduleInstance.Tables) > 0,
   117  				HasDataInstances:    len(tc.moduleInstance.DataInstances) > 0,
   118  				HasElementInstances: len(tc.moduleInstance.ElementInstances) > 0,
   119  			}
   120  			for _, g := range tc.moduleInstance.Globals {
   121  				ir.Globals = append(ir.Globals, g.Type)
   122  			}
   123  			compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, ir)
   124  			me := &moduleEngine{functions: make([]function, 10)}
   125  			tc.moduleInstance.Engine = me
   126  
   127  			err := compiler.compileModuleContextInitialization()
   128  			require.NoError(t, err)
   129  			require.Zero(t, len(compiler.runtimeValueLocationStack().usedRegisters.list()), "expected no usedRegisters")
   130  
   131  			compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   132  
   133  			code := asm.CodeSegment{}
   134  			defer func() { require.NoError(t, code.Unmap()) }()
   135  
   136  			// Generate the code under test.
   137  			_, err = compiler.compile(code.NextCodeSection())
   138  			require.NoError(t, err)
   139  
   140  			env.exec(code.Bytes())
   141  
   142  			// Check the exit status.
   143  			require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   144  
   145  			// Check if the fields of callEngine.moduleContext are updated.
   146  			bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Globals))
   147  			require.Equal(t, bufSliceHeader.Data, ce.moduleContext.globalElement0Address)
   148  
   149  			if tc.moduleInstance.MemoryInstance != nil {
   150  				bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.MemoryInstance.Buffer))
   151  				require.Equal(t, bufSliceHeader.Data, ce.moduleContext.memoryElement0Address)
   152  				require.Equal(t, tc.moduleInstance.MemoryInstance, ce.moduleContext.memoryInstance)
   153  			}
   154  
   155  			if len(tc.moduleInstance.Tables) > 0 {
   156  				tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Tables))
   157  				require.Equal(t, tableHeader.Data, ce.moduleContext.tablesElement0Address)
   158  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address)
   159  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.Tables[0])), ce.moduleContext.tablesElement0Address)
   160  			}
   161  
   162  			if len(tc.moduleInstance.DataInstances) > 0 {
   163  				dataInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.DataInstances))
   164  				require.Equal(t, dataInstancesHeader.Data, ce.moduleContext.dataInstancesElement0Address)
   165  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.DataInstances[0])), ce.moduleContext.dataInstancesElement0Address)
   166  			}
   167  
   168  			if len(tc.moduleInstance.ElementInstances) > 0 {
   169  				elementInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.ElementInstances))
   170  				require.Equal(t, elementInstancesHeader.Data, ce.moduleContext.elementInstancesElement0Address)
   171  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.ElementInstances[0])), ce.moduleContext.elementInstancesElement0Address)
   172  			}
   173  
   174  			require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.functionsElement0Address)
   175  		})
   176  	}
   177  }
   178  
   179  func TestCompiler_compileMaybeGrowStack(t *testing.T) {
   180  	t.Run("not grow", func(t *testing.T) {
   181  		const stackPointerCeil = 5
   182  		for _, baseOffset := range []uint64{5, 10, 20} {
   183  			t.Run(fmt.Sprintf("%d", baseOffset), func(t *testing.T) {
   184  				env := newCompilerEnvironment()
   185  				compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
   186  
   187  				err := compiler.compilePreamble()
   188  				require.NoError(t, err)
   189  
   190  				stackLen := uint64(len(env.stack()))
   191  				stackBasePointer := stackLen - baseOffset // Ceil <= stackLen - stackBasePointer = no need to grow!
   192  				compiler.assignStackPointerCeil(stackPointerCeil)
   193  				env.setStackBasePointer(stackBasePointer)
   194  
   195  				compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   196  
   197  				code := asm.CodeSegment{}
   198  				defer func() { require.NoError(t, code.Unmap()) }()
   199  
   200  				// Generate and run the code under test.
   201  				_, err = compiler.compile(code.NextCodeSection())
   202  				require.NoError(t, err)
   203  				env.exec(code.Bytes())
   204  
   205  				// The status code must be "Returned", not "BuiltinFunctionCall".
   206  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   207  			})
   208  		}
   209  	})
   210  
   211  	defaultStackLen := uint64(initialStackSize)
   212  	t.Run("grow", func(t *testing.T) {
   213  		tests := []struct {
   214  			name             string
   215  			stackPointerCeil uint64
   216  			stackBasePointer uint64
   217  		}{
   218  			{
   219  				name:             "ceil=6/sbp=len-5",
   220  				stackPointerCeil: 6,
   221  				stackBasePointer: defaultStackLen - 5,
   222  			},
   223  			{
   224  				name:             "ceil=10000/sbp=0",
   225  				stackPointerCeil: 10000,
   226  				stackBasePointer: 0,
   227  			},
   228  		}
   229  
   230  		for _, tc := range tests {
   231  			tc := tc
   232  			t.Run(tc.name, func(t *testing.T) {
   233  				env := newCompilerEnvironment()
   234  				compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
   235  
   236  				err := compiler.compilePreamble()
   237  				require.NoError(t, err)
   238  
   239  				// On the return from grow value stack, we simply return.
   240  				err = compiler.compileReturnFunction()
   241  				require.NoError(t, err)
   242  
   243  				code := asm.CodeSegment{}
   244  				defer func() { require.NoError(t, code.Unmap()) }()
   245  
   246  				// Generate code under test with the given stackPointerCeil.
   247  				compiler.setStackPointerCeil(tc.stackPointerCeil)
   248  				_, err = compiler.compile(code.NextCodeSection())
   249  				require.NoError(t, err)
   250  
   251  				// And run the code with the specified stackBasePointer.
   252  				env.setStackBasePointer(tc.stackBasePointer)
   253  				env.exec(code.Bytes())
   254  
   255  				// Check if the call exits with builtin function call status.
   256  				require.Equal(t, nativeCallStatusCodeCallBuiltInFunction, env.compilerStatus())
   257  
   258  				// Reenter from the return address.
   259  				returnAddress := env.ce.returnAddress
   260  				require.True(t, returnAddress != 0, "returnAddress was zero %d", returnAddress)
   261  				nativecall(returnAddress, env.callEngine(), env.module())
   262  
   263  				// Check the result. This should be "Returned".
   264  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   265  			})
   266  		}
   267  	})
   268  }