github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/frontend/frontend.go (about) 1 // Package frontend implements the translation of WebAssembly to SSA IR using the ssa package. 2 package frontend 3 4 import ( 5 "bytes" 6 7 "github.com/bananabytelabs/wazero/internal/engine/wazevo/ssa" 8 "github.com/bananabytelabs/wazero/internal/engine/wazevo/wazevoapi" 9 "github.com/bananabytelabs/wazero/internal/wasm" 10 ) 11 12 // Compiler is in charge of lowering Wasm to SSA IR, and does the optimization 13 // on top of it in architecture-independent way. 14 type Compiler struct { 15 // Per-module data that is used across all functions. 16 17 m *wasm.Module 18 offset *wazevoapi.ModuleContextOffsetData 19 // ssaBuilder is a ssa.Builder used by this frontend. 20 ssaBuilder ssa.Builder 21 signatures map[*wasm.FunctionType]*ssa.Signature 22 listenerSignatures map[*wasm.FunctionType][2]*ssa.Signature 23 memoryGrowSig ssa.Signature 24 checkModuleExitCodeSig ssa.Signature 25 tableGrowSig ssa.Signature 26 refFuncSig ssa.Signature 27 memmoveSig ssa.Signature 28 checkModuleExitCodeArg [1]ssa.Value 29 ensureTermination bool 30 31 // Followings are reset by per function. 32 33 // wasmLocalToVariable maps the index (considered as wasm.Index of locals) 34 // to the corresponding ssa.Variable. 35 wasmLocalToVariable map[wasm.Index]ssa.Variable 36 wasmLocalFunctionIndex wasm.Index 37 wasmFunctionTypeIndex wasm.Index 38 wasmFunctionTyp *wasm.FunctionType 39 wasmFunctionLocalTypes []wasm.ValueType 40 wasmFunctionBody []byte 41 wasmFunctionBodyOffsetInCodeSection uint64 42 memoryBaseVariable, memoryLenVariable ssa.Variable 43 needMemory bool 44 globalVariables []ssa.Variable 45 globalVariablesTypes []ssa.Type 46 mutableGlobalVariablesIndexes []wasm.Index // index to ^. 47 needListener bool 48 needSourceOffsetInfo bool 49 // br is reused during lowering. 50 br *bytes.Reader 51 loweringState loweringState 52 53 knownSafeBounds []knownSafeBound 54 knownSafeBoundsSet []ssa.ValueID 55 56 execCtxPtrValue, moduleCtxPtrValue ssa.Value 57 } 58 59 type knownSafeBound struct { 60 bound uint64 61 absoluteAddr ssa.Value 62 } 63 64 // NewFrontendCompiler returns a frontend Compiler. 65 func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler { 66 c := &Compiler{ 67 m: m, 68 ssaBuilder: ssaBuilder, 69 br: bytes.NewReader(nil), 70 wasmLocalToVariable: make(map[wasm.Index]ssa.Variable), 71 offset: offset, 72 ensureTermination: ensureTermination, 73 needSourceOffsetInfo: sourceInfo, 74 } 75 c.declareSignatures(listenerOn) 76 return c 77 } 78 79 func (c *Compiler) declareSignatures(listenerOn bool) { 80 m := c.m 81 c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+2) 82 if listenerOn { 83 c.listenerSignatures = make(map[*wasm.FunctionType][2]*ssa.Signature, len(m.TypeSection)) 84 } 85 for i := range m.TypeSection { 86 wasmSig := &m.TypeSection[i] 87 sig := SignatureForWasmFunctionType(wasmSig) 88 sig.ID = ssa.SignatureID(i) 89 c.signatures[wasmSig] = &sig 90 c.ssaBuilder.DeclareSignature(&sig) 91 92 if listenerOn { 93 beforeSig, afterSig := SignatureForListener(wasmSig) 94 beforeSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection)) 95 afterSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection))*2 96 c.listenerSignatures[wasmSig] = [2]*ssa.Signature{beforeSig, afterSig} 97 c.ssaBuilder.DeclareSignature(beforeSig) 98 c.ssaBuilder.DeclareSignature(afterSig) 99 } 100 } 101 102 begin := ssa.SignatureID(len(m.TypeSection)) 103 if listenerOn { 104 begin *= 3 105 } 106 c.memoryGrowSig = ssa.Signature{ 107 ID: begin, 108 // Takes execution context and the page size to grow. 109 Params: []ssa.Type{ssa.TypeI64, ssa.TypeI32}, 110 // Returns the previous page size. 111 Results: []ssa.Type{ssa.TypeI32}, 112 } 113 c.ssaBuilder.DeclareSignature(&c.memoryGrowSig) 114 115 c.checkModuleExitCodeSig = ssa.Signature{ 116 ID: c.memoryGrowSig.ID + 1, 117 // Only takes execution context. 118 Params: []ssa.Type{ssa.TypeI64}, 119 } 120 c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig) 121 122 c.tableGrowSig = ssa.Signature{ 123 ID: c.checkModuleExitCodeSig.ID + 1, 124 Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* table index */, ssa.TypeI32 /* num */, ssa.TypeI64 /* ref */}, 125 // Returns the previous size. 126 Results: []ssa.Type{ssa.TypeI32}, 127 } 128 c.ssaBuilder.DeclareSignature(&c.tableGrowSig) 129 130 c.refFuncSig = ssa.Signature{ 131 ID: c.tableGrowSig.ID + 1, 132 Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* func index */}, 133 // Returns the function reference. 134 Results: []ssa.Type{ssa.TypeI64}, 135 } 136 c.ssaBuilder.DeclareSignature(&c.refFuncSig) 137 138 c.memmoveSig = ssa.Signature{ 139 ID: c.refFuncSig.ID + 1, 140 // dst, src, and the byte count. 141 Params: []ssa.Type{ssa.TypeI64, ssa.TypeI64, ssa.TypeI32}, 142 } 143 c.ssaBuilder.DeclareSignature(&c.memmoveSig) 144 } 145 146 // SignatureForWasmFunctionType returns the ssa.Signature for the given wasm.FunctionType. 147 func SignatureForWasmFunctionType(typ *wasm.FunctionType) ssa.Signature { 148 sig := ssa.Signature{ 149 // +2 to pass moduleContextPtr and executionContextPtr. See the inline comment LowerToSSA. 150 Params: make([]ssa.Type, len(typ.Params)+2), 151 Results: make([]ssa.Type, len(typ.Results)), 152 } 153 sig.Params[0] = executionContextPtrTyp 154 sig.Params[1] = moduleContextPtrTyp 155 for j, typ := range typ.Params { 156 sig.Params[j+2] = WasmTypeToSSAType(typ) 157 } 158 for j, typ := range typ.Results { 159 sig.Results[j] = WasmTypeToSSAType(typ) 160 } 161 return sig 162 } 163 164 // Init initializes the state of frontendCompiler and make it ready for a next function. 165 func (c *Compiler) Init(idx, typIndex wasm.Index, typ *wasm.FunctionType, localTypes []wasm.ValueType, body []byte, needListener bool, bodyOffsetInCodeSection uint64) { 166 c.ssaBuilder.Init(c.signatures[typ]) 167 c.loweringState.reset() 168 169 c.wasmFunctionTypeIndex = typIndex 170 c.wasmLocalFunctionIndex = idx 171 c.wasmFunctionTyp = typ 172 c.wasmFunctionLocalTypes = localTypes 173 c.wasmFunctionBody = body 174 c.wasmFunctionBodyOffsetInCodeSection = bodyOffsetInCodeSection 175 c.needListener = needListener 176 } 177 178 // Note: this assumes 64-bit platform (I believe we won't have 32-bit backend ;)). 179 const executionContextPtrTyp, moduleContextPtrTyp = ssa.TypeI64, ssa.TypeI64 180 181 // LowerToSSA lowers the current function to SSA function which will be held by ssaBuilder. 182 // After calling this, the caller will be able to access the SSA info in *Compiler.ssaBuilder. 183 // 184 // Note that this only does the naive lowering, and do not do any optimization, instead the caller is expected to do so. 185 func (c *Compiler) LowerToSSA() { 186 builder := c.ssaBuilder 187 188 // Set up the entry block. 189 entryBlock := builder.AllocateBasicBlock() 190 builder.SetCurrentBlock(entryBlock) 191 192 // Functions always take two parameters in addition to Wasm-level parameters: 193 // 194 // 1. executionContextPtr: pointer to the *executionContext in wazevo package. 195 // This will be used to exit the execution in the face of trap, plus used for host function calls. 196 // 197 // 2. moduleContextPtr: pointer to the *moduleContextOpaque in wazevo package. 198 // This will be used to access memory, etc. Also, this will be used during host function calls. 199 // 200 // Note: it's clear that sometimes a function won't need them. For example, 201 // if the function doesn't trap and doesn't make function call, then 202 // we might be able to eliminate the parameter. However, if that function 203 // can be called via call_indirect, then we cannot eliminate because the 204 // signature won't match with the expected one. 205 // TODO: maybe there's some way to do this optimization without glitches, but so far I have no clue about the feasibility. 206 // 207 // Note: In Wasmtime or many other runtimes, moduleContextPtr is called "vmContext". Also note that `moduleContextPtr` 208 // is wazero-specific since other runtimes can naturally use the OS-level signal to do this job thanks to the fact that 209 // they can use native stack vs wazero cannot use Go-routine stack and have to use Go-runtime allocated []byte as a stack. 210 c.execCtxPtrValue = entryBlock.AddParam(builder, executionContextPtrTyp) 211 c.moduleCtxPtrValue = entryBlock.AddParam(builder, moduleContextPtrTyp) 212 builder.AnnotateValue(c.execCtxPtrValue, "exec_ctx") 213 builder.AnnotateValue(c.moduleCtxPtrValue, "module_ctx") 214 215 for i, typ := range c.wasmFunctionTyp.Params { 216 st := WasmTypeToSSAType(typ) 217 variable := builder.DeclareVariable(st) 218 value := entryBlock.AddParam(builder, st) 219 builder.DefineVariable(variable, value, entryBlock) 220 c.wasmLocalToVariable[wasm.Index(i)] = variable 221 } 222 c.declareWasmLocals(entryBlock) 223 c.declareNecessaryVariables() 224 225 c.lowerBody(entryBlock) 226 } 227 228 // localVariable returns the SSA variable for the given Wasm local index. 229 func (c *Compiler) localVariable(index wasm.Index) ssa.Variable { 230 return c.wasmLocalToVariable[index] 231 } 232 233 // declareWasmLocals declares the SSA variables for the Wasm locals. 234 func (c *Compiler) declareWasmLocals(entry ssa.BasicBlock) { 235 localCount := wasm.Index(len(c.wasmFunctionTyp.Params)) 236 for i, typ := range c.wasmFunctionLocalTypes { 237 st := WasmTypeToSSAType(typ) 238 variable := c.ssaBuilder.DeclareVariable(st) 239 c.wasmLocalToVariable[wasm.Index(i)+localCount] = variable 240 241 zeroInst := c.ssaBuilder.AllocateInstruction() 242 switch st { 243 case ssa.TypeI32: 244 zeroInst.AsIconst32(0) 245 case ssa.TypeI64: 246 zeroInst.AsIconst64(0) 247 case ssa.TypeF32: 248 zeroInst.AsF32const(0) 249 case ssa.TypeF64: 250 zeroInst.AsF64const(0) 251 case ssa.TypeV128: 252 zeroInst.AsVconst(0, 0) 253 default: 254 panic("TODO: " + wasm.ValueTypeName(typ)) 255 } 256 257 c.ssaBuilder.InsertInstruction(zeroInst) 258 value := zeroInst.Return() 259 c.ssaBuilder.DefineVariable(variable, value, entry) 260 } 261 } 262 263 func (c *Compiler) declareNecessaryVariables() { 264 c.needMemory = c.m.ImportMemoryCount > 0 || c.m.MemorySection != nil 265 if c.needMemory { 266 c.memoryBaseVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64) 267 c.memoryLenVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64) 268 } 269 270 c.globalVariables = c.globalVariables[:0] 271 c.mutableGlobalVariablesIndexes = c.mutableGlobalVariablesIndexes[:0] 272 c.globalVariablesTypes = c.globalVariablesTypes[:0] 273 for _, imp := range c.m.ImportSection { 274 if imp.Type == wasm.ExternTypeGlobal { 275 desc := imp.DescGlobal 276 c.declareWasmGlobal(desc.ValType, desc.Mutable) 277 } 278 } 279 for _, g := range c.m.GlobalSection { 280 desc := g.Type 281 c.declareWasmGlobal(desc.ValType, desc.Mutable) 282 } 283 284 // TODO: add tables. 285 } 286 287 func (c *Compiler) declareWasmGlobal(typ wasm.ValueType, mutable bool) { 288 var st ssa.Type 289 switch typ { 290 case wasm.ValueTypeI32: 291 st = ssa.TypeI32 292 case wasm.ValueTypeI64, 293 // Both externref and funcref are represented as I64 since we only support 64-bit platforms. 294 wasm.ValueTypeExternref, wasm.ValueTypeFuncref: 295 st = ssa.TypeI64 296 case wasm.ValueTypeF32: 297 st = ssa.TypeF32 298 case wasm.ValueTypeF64: 299 st = ssa.TypeF64 300 case wasm.ValueTypeV128: 301 st = ssa.TypeV128 302 default: 303 panic("TODO: " + wasm.ValueTypeName(typ)) 304 } 305 v := c.ssaBuilder.DeclareVariable(st) 306 index := wasm.Index(len(c.globalVariables)) 307 c.globalVariables = append(c.globalVariables, v) 308 c.globalVariablesTypes = append(c.globalVariablesTypes, st) 309 if mutable { 310 c.mutableGlobalVariablesIndexes = append(c.mutableGlobalVariablesIndexes, index) 311 } 312 } 313 314 // WasmTypeToSSAType converts wasm.ValueType to ssa.Type. 315 func WasmTypeToSSAType(vt wasm.ValueType) ssa.Type { 316 switch vt { 317 case wasm.ValueTypeI32: 318 return ssa.TypeI32 319 case wasm.ValueTypeI64, 320 // Both externref and funcref are represented as I64 since we only support 64-bit platforms. 321 wasm.ValueTypeExternref, wasm.ValueTypeFuncref: 322 return ssa.TypeI64 323 case wasm.ValueTypeF32: 324 return ssa.TypeF32 325 case wasm.ValueTypeF64: 326 return ssa.TypeF64 327 case wasm.ValueTypeV128: 328 return ssa.TypeV128 329 default: 330 panic("TODO: " + wasm.ValueTypeName(vt)) 331 } 332 } 333 334 // addBlockParamsFromWasmTypes adds the block parameters to the given block. 335 func (c *Compiler) addBlockParamsFromWasmTypes(tps []wasm.ValueType, blk ssa.BasicBlock) { 336 for _, typ := range tps { 337 st := WasmTypeToSSAType(typ) 338 blk.AddParam(c.ssaBuilder, st) 339 } 340 } 341 342 // formatBuilder outputs the constructed SSA function as a string with a source information. 343 func (c *Compiler) formatBuilder() string { 344 return c.ssaBuilder.Format() 345 } 346 347 // SignatureForListener returns the signatures for the listener functions. 348 func SignatureForListener(wasmSig *wasm.FunctionType) (*ssa.Signature, *ssa.Signature) { 349 beforeSig := &ssa.Signature{} 350 beforeSig.Params = make([]ssa.Type, len(wasmSig.Params)+2) 351 beforeSig.Params[0] = ssa.TypeI64 // Execution context. 352 beforeSig.Params[1] = ssa.TypeI32 // Function index. 353 for i, p := range wasmSig.Params { 354 beforeSig.Params[i+2] = WasmTypeToSSAType(p) 355 } 356 afterSig := &ssa.Signature{} 357 afterSig.Params = make([]ssa.Type, len(wasmSig.Results)+2) 358 afterSig.Params[0] = ssa.TypeI64 // Execution context. 359 afterSig.Params[1] = ssa.TypeI32 // Function index. 360 for i, p := range wasmSig.Results { 361 afterSig.Params[i+2] = WasmTypeToSSAType(p) 362 } 363 return beforeSig, afterSig 364 } 365 366 // isBoundSafe returns true if the given value is known to be safe to access up to the given bound. 367 func (c *Compiler) getKnownSafeBound(v ssa.ValueID) *knownSafeBound { 368 if int(v) >= len(c.knownSafeBounds) { 369 return nil 370 } 371 return &c.knownSafeBounds[v] 372 } 373 374 // recordKnownSafeBound records the given safe bound for the given value. 375 func (c *Compiler) recordKnownSafeBound(v ssa.ValueID, safeBound uint64, absoluteAddr ssa.Value) { 376 if int(v) >= len(c.knownSafeBounds) { 377 c.knownSafeBounds = append(c.knownSafeBounds, make([]knownSafeBound, v+1)...) 378 } 379 380 if exiting := c.knownSafeBounds[v]; exiting.bound == 0 { 381 c.knownSafeBounds[v] = knownSafeBound{ 382 bound: safeBound, 383 absoluteAddr: absoluteAddr, 384 } 385 c.knownSafeBoundsSet = append(c.knownSafeBoundsSet, v) 386 } else if safeBound > exiting.bound { 387 c.knownSafeBounds[v].bound = safeBound 388 } 389 } 390 391 // clearSafeBounds clears the known safe bounds. This must be called 392 // after the compilation of each block. 393 func (c *Compiler) clearSafeBounds() { 394 for _, v := range c.knownSafeBoundsSet { 395 ptr := &c.knownSafeBounds[v] 396 ptr.bound = 0 397 } 398 c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0] 399 } 400 401 func (k *knownSafeBound) valid() bool { 402 return k != nil && k.bound > 0 403 }