github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_initialization_test.go (about) 1 package compiler 2 3 import ( 4 "fmt" 5 "reflect" 6 "testing" 7 "unsafe" 8 9 "github.com/bananabytelabs/wazero/internal/asm" 10 "github.com/bananabytelabs/wazero/internal/testing/require" 11 "github.com/bananabytelabs/wazero/internal/wasm" 12 "github.com/bananabytelabs/wazero/internal/wazeroir" 13 ) 14 15 func TestCompiler_compileModuleContextInitialization(t *testing.T) { 16 tests := []struct { 17 name string 18 moduleInstance *wasm.ModuleInstance 19 }{ 20 { 21 name: "no nil", 22 moduleInstance: &wasm.ModuleInstance{ 23 Globals: []*wasm.GlobalInstance{{Val: 100}}, 24 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, 25 Tables: []*wasm.TableInstance{ 26 {References: make([]wasm.Reference, 20)}, 27 {References: make([]wasm.Reference, 10)}, 28 }, 29 TypeIDs: make([]wasm.FunctionTypeID, 10), 30 DataInstances: make([][]byte, 10), 31 ElementInstances: make([]wasm.ElementInstance, 10), 32 }, 33 }, 34 { 35 name: "element instances nil", 36 moduleInstance: &wasm.ModuleInstance{ 37 Globals: []*wasm.GlobalInstance{{Val: 100}}, 38 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, 39 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}}, 40 TypeIDs: make([]wasm.FunctionTypeID, 10), 41 DataInstances: make([][]byte, 10), 42 ElementInstances: nil, 43 }, 44 }, 45 { 46 name: "data instances nil", 47 moduleInstance: &wasm.ModuleInstance{ 48 Globals: []*wasm.GlobalInstance{{Val: 100}}, 49 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, 50 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}}, 51 TypeIDs: make([]wasm.FunctionTypeID, 10), 52 DataInstances: nil, 53 ElementInstances: make([]wasm.ElementInstance, 10), 54 }, 55 }, 56 { 57 name: "globals nil", 58 moduleInstance: &wasm.ModuleInstance{ 59 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, 60 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}}, 61 TypeIDs: make([]wasm.FunctionTypeID, 10), 62 DataInstances: make([][]byte, 10), 63 ElementInstances: make([]wasm.ElementInstance, 10), 64 }, 65 }, 66 { 67 name: "memory nil", 68 moduleInstance: &wasm.ModuleInstance{ 69 Globals: []*wasm.GlobalInstance{{Val: 100}}, 70 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}}, 71 TypeIDs: make([]wasm.FunctionTypeID, 10), 72 DataInstances: make([][]byte, 10), 73 ElementInstances: make([]wasm.ElementInstance, 10), 74 }, 75 }, 76 { 77 name: "table nil", 78 moduleInstance: &wasm.ModuleInstance{ 79 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, 80 Tables: []*wasm.TableInstance{{References: nil}}, 81 TypeIDs: make([]wasm.FunctionTypeID, 10), 82 DataInstances: make([][]byte, 10), 83 ElementInstances: make([]wasm.ElementInstance, 10), 84 }, 85 }, 86 { 87 name: "table empty", 88 moduleInstance: &wasm.ModuleInstance{ 89 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}}, 90 TypeIDs: make([]wasm.FunctionTypeID, 10), 91 DataInstances: make([][]byte, 10), 92 ElementInstances: make([]wasm.ElementInstance, 10), 93 }, 94 }, 95 { 96 name: "memory zero length", 97 moduleInstance: &wasm.ModuleInstance{ 98 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 0)}, 99 }, 100 }, 101 { 102 name: "all nil except mod engine", 103 moduleInstance: &wasm.ModuleInstance{}, 104 }, 105 } 106 107 for _, tt := range tests { 108 tc := tt 109 t.Run(tc.name, func(t *testing.T) { 110 env := newCompilerEnvironment() 111 env.moduleInstance = tc.moduleInstance 112 ce := env.callEngine() 113 114 ir := &wazeroir.CompilationResult{ 115 HasMemory: tc.moduleInstance.MemoryInstance != nil, 116 HasTable: len(tc.moduleInstance.Tables) > 0, 117 HasDataInstances: len(tc.moduleInstance.DataInstances) > 0, 118 HasElementInstances: len(tc.moduleInstance.ElementInstances) > 0, 119 } 120 for _, g := range tc.moduleInstance.Globals { 121 ir.Globals = append(ir.Globals, g.Type) 122 } 123 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, ir) 124 me := &moduleEngine{functions: make([]function, 10)} 125 tc.moduleInstance.Engine = me 126 127 err := compiler.compileModuleContextInitialization() 128 require.NoError(t, err) 129 require.Zero(t, len(compiler.runtimeValueLocationStack().usedRegisters.list()), "expected no usedRegisters") 130 131 compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned) 132 133 code := asm.CodeSegment{} 134 defer func() { require.NoError(t, code.Unmap()) }() 135 136 // Generate the code under test. 137 _, err = compiler.compile(code.NextCodeSection()) 138 require.NoError(t, err) 139 140 env.exec(code.Bytes()) 141 142 // Check the exit status. 143 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus()) 144 145 // Check if the fields of callEngine.moduleContext are updated. 146 bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Globals)) 147 require.Equal(t, bufSliceHeader.Data, ce.moduleContext.globalElement0Address) 148 149 if tc.moduleInstance.MemoryInstance != nil { 150 bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.MemoryInstance.Buffer)) 151 require.Equal(t, uint64(bufSliceHeader.Len), ce.moduleContext.memorySliceLen) 152 require.Equal(t, bufSliceHeader.Data, ce.moduleContext.memoryElement0Address) 153 require.Equal(t, tc.moduleInstance.MemoryInstance, ce.moduleContext.memoryInstance) 154 } 155 156 if len(tc.moduleInstance.Tables) > 0 { 157 tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Tables)) 158 require.Equal(t, tableHeader.Data, ce.moduleContext.tablesElement0Address) 159 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address) 160 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.Tables[0])), ce.moduleContext.tablesElement0Address) 161 } 162 163 if len(tc.moduleInstance.DataInstances) > 0 { 164 dataInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.DataInstances)) 165 require.Equal(t, dataInstancesHeader.Data, ce.moduleContext.dataInstancesElement0Address) 166 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.DataInstances[0])), ce.moduleContext.dataInstancesElement0Address) 167 } 168 169 if len(tc.moduleInstance.ElementInstances) > 0 { 170 elementInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.ElementInstances)) 171 require.Equal(t, elementInstancesHeader.Data, ce.moduleContext.elementInstancesElement0Address) 172 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.ElementInstances[0])), ce.moduleContext.elementInstancesElement0Address) 173 } 174 175 require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.functionsElement0Address) 176 }) 177 } 178 } 179 180 func TestCompiler_compileMaybeGrowStack(t *testing.T) { 181 t.Run("not grow", func(t *testing.T) { 182 const stackPointerCeil = 5 183 for _, baseOffset := range []uint64{5, 10, 20} { 184 t.Run(fmt.Sprintf("%d", baseOffset), func(t *testing.T) { 185 env := newCompilerEnvironment() 186 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil) 187 188 err := compiler.compilePreamble() 189 require.NoError(t, err) 190 191 stackLen := uint64(len(env.stack())) 192 stackBasePointer := stackLen - baseOffset // Ceil <= stackLen - stackBasePointer = no need to grow! 193 compiler.assignStackPointerCeil(stackPointerCeil) 194 env.setStackBasePointer(stackBasePointer) 195 196 compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned) 197 198 code := asm.CodeSegment{} 199 defer func() { require.NoError(t, code.Unmap()) }() 200 201 // Generate and run the code under test. 202 _, err = compiler.compile(code.NextCodeSection()) 203 require.NoError(t, err) 204 env.exec(code.Bytes()) 205 206 // The status code must be "Returned", not "BuiltinFunctionCall". 207 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus()) 208 }) 209 } 210 }) 211 212 defaultStackLen := uint64(initialStackSize) 213 t.Run("grow", func(t *testing.T) { 214 tests := []struct { 215 name string 216 stackPointerCeil uint64 217 stackBasePointer uint64 218 }{ 219 { 220 name: "ceil=6/sbp=len-5", 221 stackPointerCeil: 6, 222 stackBasePointer: defaultStackLen - 5, 223 }, 224 { 225 name: "ceil=10000/sbp=0", 226 stackPointerCeil: 10000, 227 stackBasePointer: 0, 228 }, 229 } 230 231 for _, tc := range tests { 232 tc := tc 233 t.Run(tc.name, func(t *testing.T) { 234 env := newCompilerEnvironment() 235 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil) 236 237 err := compiler.compilePreamble() 238 require.NoError(t, err) 239 240 // On the return from grow value stack, we simply return. 241 err = compiler.compileReturnFunction() 242 require.NoError(t, err) 243 244 code := asm.CodeSegment{} 245 defer func() { require.NoError(t, code.Unmap()) }() 246 247 // Generate code under test with the given stackPointerCeil. 248 compiler.setStackPointerCeil(tc.stackPointerCeil) 249 _, err = compiler.compile(code.NextCodeSection()) 250 require.NoError(t, err) 251 252 // And run the code with the specified stackBasePointer. 253 env.setStackBasePointer(tc.stackBasePointer) 254 env.exec(code.Bytes()) 255 256 // Check if the call exits with builtin function call status. 257 require.Equal(t, nativeCallStatusCodeCallBuiltInFunction, env.compilerStatus()) 258 259 // Reenter from the return address. 260 returnAddress := env.ce.returnAddress 261 require.True(t, returnAddress != 0, "returnAddress was zero %d", returnAddress) 262 nativecall(returnAddress, env.callEngine(), env.module()) 263 264 // Check the result. This should be "Returned". 265 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus()) 266 }) 267 } 268 }) 269 }