github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_initialization_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  	"unsafe"
     8  
     9  	"github.com/bananabytelabs/wazero/internal/asm"
    10  	"github.com/bananabytelabs/wazero/internal/testing/require"
    11  	"github.com/bananabytelabs/wazero/internal/wasm"
    12  	"github.com/bananabytelabs/wazero/internal/wazeroir"
    13  )
    14  
    15  func TestCompiler_compileModuleContextInitialization(t *testing.T) {
    16  	tests := []struct {
    17  		name           string
    18  		moduleInstance *wasm.ModuleInstance
    19  	}{
    20  		{
    21  			name: "no nil",
    22  			moduleInstance: &wasm.ModuleInstance{
    23  				Globals:        []*wasm.GlobalInstance{{Val: 100}},
    24  				MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    25  				Tables: []*wasm.TableInstance{
    26  					{References: make([]wasm.Reference, 20)},
    27  					{References: make([]wasm.Reference, 10)},
    28  				},
    29  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    30  				DataInstances:    make([][]byte, 10),
    31  				ElementInstances: make([]wasm.ElementInstance, 10),
    32  			},
    33  		},
    34  		{
    35  			name: "element instances nil",
    36  			moduleInstance: &wasm.ModuleInstance{
    37  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    38  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    39  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    40  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    41  				DataInstances:    make([][]byte, 10),
    42  				ElementInstances: nil,
    43  			},
    44  		},
    45  		{
    46  			name: "data instances nil",
    47  			moduleInstance: &wasm.ModuleInstance{
    48  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    49  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    50  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    51  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    52  				DataInstances:    nil,
    53  				ElementInstances: make([]wasm.ElementInstance, 10),
    54  			},
    55  		},
    56  		{
    57  			name: "globals nil",
    58  			moduleInstance: &wasm.ModuleInstance{
    59  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    60  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    61  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    62  				DataInstances:    make([][]byte, 10),
    63  				ElementInstances: make([]wasm.ElementInstance, 10),
    64  			},
    65  		},
    66  		{
    67  			name: "memory nil",
    68  			moduleInstance: &wasm.ModuleInstance{
    69  				Globals:          []*wasm.GlobalInstance{{Val: 100}},
    70  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    71  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    72  				DataInstances:    make([][]byte, 10),
    73  				ElementInstances: make([]wasm.ElementInstance, 10),
    74  			},
    75  		},
    76  		{
    77  			name: "table nil",
    78  			moduleInstance: &wasm.ModuleInstance{
    79  				MemoryInstance:   &wasm.MemoryInstance{Buffer: make([]byte, 10)},
    80  				Tables:           []*wasm.TableInstance{{References: nil}},
    81  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    82  				DataInstances:    make([][]byte, 10),
    83  				ElementInstances: make([]wasm.ElementInstance, 10),
    84  			},
    85  		},
    86  		{
    87  			name: "table empty",
    88  			moduleInstance: &wasm.ModuleInstance{
    89  				Tables:           []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
    90  				TypeIDs:          make([]wasm.FunctionTypeID, 10),
    91  				DataInstances:    make([][]byte, 10),
    92  				ElementInstances: make([]wasm.ElementInstance, 10),
    93  			},
    94  		},
    95  		{
    96  			name: "memory zero length",
    97  			moduleInstance: &wasm.ModuleInstance{
    98  				MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 0)},
    99  			},
   100  		},
   101  		{
   102  			name:           "all nil except mod engine",
   103  			moduleInstance: &wasm.ModuleInstance{},
   104  		},
   105  	}
   106  
   107  	for _, tt := range tests {
   108  		tc := tt
   109  		t.Run(tc.name, func(t *testing.T) {
   110  			env := newCompilerEnvironment()
   111  			env.moduleInstance = tc.moduleInstance
   112  			ce := env.callEngine()
   113  
   114  			ir := &wazeroir.CompilationResult{
   115  				HasMemory:           tc.moduleInstance.MemoryInstance != nil,
   116  				HasTable:            len(tc.moduleInstance.Tables) > 0,
   117  				HasDataInstances:    len(tc.moduleInstance.DataInstances) > 0,
   118  				HasElementInstances: len(tc.moduleInstance.ElementInstances) > 0,
   119  			}
   120  			for _, g := range tc.moduleInstance.Globals {
   121  				ir.Globals = append(ir.Globals, g.Type)
   122  			}
   123  			compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, ir)
   124  			me := &moduleEngine{functions: make([]function, 10)}
   125  			tc.moduleInstance.Engine = me
   126  
   127  			err := compiler.compileModuleContextInitialization()
   128  			require.NoError(t, err)
   129  			require.Zero(t, len(compiler.runtimeValueLocationStack().usedRegisters.list()), "expected no usedRegisters")
   130  
   131  			compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   132  
   133  			code := asm.CodeSegment{}
   134  			defer func() { require.NoError(t, code.Unmap()) }()
   135  
   136  			// Generate the code under test.
   137  			_, err = compiler.compile(code.NextCodeSection())
   138  			require.NoError(t, err)
   139  
   140  			env.exec(code.Bytes())
   141  
   142  			// Check the exit status.
   143  			require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   144  
   145  			// Check if the fields of callEngine.moduleContext are updated.
   146  			bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Globals))
   147  			require.Equal(t, bufSliceHeader.Data, ce.moduleContext.globalElement0Address)
   148  
   149  			if tc.moduleInstance.MemoryInstance != nil {
   150  				bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.MemoryInstance.Buffer))
   151  				require.Equal(t, uint64(bufSliceHeader.Len), ce.moduleContext.memorySliceLen)
   152  				require.Equal(t, bufSliceHeader.Data, ce.moduleContext.memoryElement0Address)
   153  				require.Equal(t, tc.moduleInstance.MemoryInstance, ce.moduleContext.memoryInstance)
   154  			}
   155  
   156  			if len(tc.moduleInstance.Tables) > 0 {
   157  				tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Tables))
   158  				require.Equal(t, tableHeader.Data, ce.moduleContext.tablesElement0Address)
   159  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address)
   160  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.Tables[0])), ce.moduleContext.tablesElement0Address)
   161  			}
   162  
   163  			if len(tc.moduleInstance.DataInstances) > 0 {
   164  				dataInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.DataInstances))
   165  				require.Equal(t, dataInstancesHeader.Data, ce.moduleContext.dataInstancesElement0Address)
   166  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.DataInstances[0])), ce.moduleContext.dataInstancesElement0Address)
   167  			}
   168  
   169  			if len(tc.moduleInstance.ElementInstances) > 0 {
   170  				elementInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.ElementInstances))
   171  				require.Equal(t, elementInstancesHeader.Data, ce.moduleContext.elementInstancesElement0Address)
   172  				require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.ElementInstances[0])), ce.moduleContext.elementInstancesElement0Address)
   173  			}
   174  
   175  			require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.functionsElement0Address)
   176  		})
   177  	}
   178  }
   179  
   180  func TestCompiler_compileMaybeGrowStack(t *testing.T) {
   181  	t.Run("not grow", func(t *testing.T) {
   182  		const stackPointerCeil = 5
   183  		for _, baseOffset := range []uint64{5, 10, 20} {
   184  			t.Run(fmt.Sprintf("%d", baseOffset), func(t *testing.T) {
   185  				env := newCompilerEnvironment()
   186  				compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
   187  
   188  				err := compiler.compilePreamble()
   189  				require.NoError(t, err)
   190  
   191  				stackLen := uint64(len(env.stack()))
   192  				stackBasePointer := stackLen - baseOffset // Ceil <= stackLen - stackBasePointer = no need to grow!
   193  				compiler.assignStackPointerCeil(stackPointerCeil)
   194  				env.setStackBasePointer(stackBasePointer)
   195  
   196  				compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
   197  
   198  				code := asm.CodeSegment{}
   199  				defer func() { require.NoError(t, code.Unmap()) }()
   200  
   201  				// Generate and run the code under test.
   202  				_, err = compiler.compile(code.NextCodeSection())
   203  				require.NoError(t, err)
   204  				env.exec(code.Bytes())
   205  
   206  				// The status code must be "Returned", not "BuiltinFunctionCall".
   207  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   208  			})
   209  		}
   210  	})
   211  
   212  	defaultStackLen := uint64(initialStackSize)
   213  	t.Run("grow", func(t *testing.T) {
   214  		tests := []struct {
   215  			name             string
   216  			stackPointerCeil uint64
   217  			stackBasePointer uint64
   218  		}{
   219  			{
   220  				name:             "ceil=6/sbp=len-5",
   221  				stackPointerCeil: 6,
   222  				stackBasePointer: defaultStackLen - 5,
   223  			},
   224  			{
   225  				name:             "ceil=10000/sbp=0",
   226  				stackPointerCeil: 10000,
   227  				stackBasePointer: 0,
   228  			},
   229  		}
   230  
   231  		for _, tc := range tests {
   232  			tc := tc
   233  			t.Run(tc.name, func(t *testing.T) {
   234  				env := newCompilerEnvironment()
   235  				compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
   236  
   237  				err := compiler.compilePreamble()
   238  				require.NoError(t, err)
   239  
   240  				// On the return from grow value stack, we simply return.
   241  				err = compiler.compileReturnFunction()
   242  				require.NoError(t, err)
   243  
   244  				code := asm.CodeSegment{}
   245  				defer func() { require.NoError(t, code.Unmap()) }()
   246  
   247  				// Generate code under test with the given stackPointerCeil.
   248  				compiler.setStackPointerCeil(tc.stackPointerCeil)
   249  				_, err = compiler.compile(code.NextCodeSection())
   250  				require.NoError(t, err)
   251  
   252  				// And run the code with the specified stackBasePointer.
   253  				env.setStackBasePointer(tc.stackBasePointer)
   254  				env.exec(code.Bytes())
   255  
   256  				// Check if the call exits with builtin function call status.
   257  				require.Equal(t, nativeCallStatusCodeCallBuiltInFunction, env.compilerStatus())
   258  
   259  				// Reenter from the return address.
   260  				returnAddress := env.ce.returnAddress
   261  				require.True(t, returnAddress != 0, "returnAddress was zero %d", returnAddress)
   262  				nativecall(returnAddress, env.callEngine(), env.module())
   263  
   264  				// Check the result. This should be "Returned".
   265  				require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   266  			})
   267  		}
   268  	})
   269  }