github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/frontend/frontend.go (about)

     1  // Package frontend implements the translation of WebAssembly to SSA IR using the ssa package.
     2  package frontend
     3  
     4  import (
     5  	"bytes"
     6  
     7  	"github.com/bananabytelabs/wazero/internal/engine/wazevo/ssa"
     8  	"github.com/bananabytelabs/wazero/internal/engine/wazevo/wazevoapi"
     9  	"github.com/bananabytelabs/wazero/internal/wasm"
    10  )
    11  
    12  // Compiler is in charge of lowering Wasm to SSA IR, and does the optimization
    13  // on top of it in architecture-independent way.
    14  type Compiler struct {
    15  	// Per-module data that is used across all functions.
    16  
    17  	m      *wasm.Module
    18  	offset *wazevoapi.ModuleContextOffsetData
    19  	// ssaBuilder is a ssa.Builder used by this frontend.
    20  	ssaBuilder             ssa.Builder
    21  	signatures             map[*wasm.FunctionType]*ssa.Signature
    22  	listenerSignatures     map[*wasm.FunctionType][2]*ssa.Signature
    23  	memoryGrowSig          ssa.Signature
    24  	checkModuleExitCodeSig ssa.Signature
    25  	tableGrowSig           ssa.Signature
    26  	refFuncSig             ssa.Signature
    27  	memmoveSig             ssa.Signature
    28  	checkModuleExitCodeArg [1]ssa.Value
    29  	ensureTermination      bool
    30  
    31  	// Followings are reset by per function.
    32  
    33  	// wasmLocalToVariable maps the index (considered as wasm.Index of locals)
    34  	// to the corresponding ssa.Variable.
    35  	wasmLocalToVariable                   map[wasm.Index]ssa.Variable
    36  	wasmLocalFunctionIndex                wasm.Index
    37  	wasmFunctionTypeIndex                 wasm.Index
    38  	wasmFunctionTyp                       *wasm.FunctionType
    39  	wasmFunctionLocalTypes                []wasm.ValueType
    40  	wasmFunctionBody                      []byte
    41  	wasmFunctionBodyOffsetInCodeSection   uint64
    42  	memoryBaseVariable, memoryLenVariable ssa.Variable
    43  	needMemory                            bool
    44  	globalVariables                       []ssa.Variable
    45  	globalVariablesTypes                  []ssa.Type
    46  	mutableGlobalVariablesIndexes         []wasm.Index // index to ^.
    47  	needListener                          bool
    48  	needSourceOffsetInfo                  bool
    49  	// br is reused during lowering.
    50  	br            *bytes.Reader
    51  	loweringState loweringState
    52  
    53  	knownSafeBounds    []knownSafeBound
    54  	knownSafeBoundsSet []ssa.ValueID
    55  
    56  	execCtxPtrValue, moduleCtxPtrValue ssa.Value
    57  }
    58  
    59  type knownSafeBound struct {
    60  	bound        uint64
    61  	absoluteAddr ssa.Value
    62  }
    63  
    64  // NewFrontendCompiler returns a frontend Compiler.
    65  func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler {
    66  	c := &Compiler{
    67  		m:                    m,
    68  		ssaBuilder:           ssaBuilder,
    69  		br:                   bytes.NewReader(nil),
    70  		wasmLocalToVariable:  make(map[wasm.Index]ssa.Variable),
    71  		offset:               offset,
    72  		ensureTermination:    ensureTermination,
    73  		needSourceOffsetInfo: sourceInfo,
    74  	}
    75  	c.declareSignatures(listenerOn)
    76  	return c
    77  }
    78  
    79  func (c *Compiler) declareSignatures(listenerOn bool) {
    80  	m := c.m
    81  	c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+2)
    82  	if listenerOn {
    83  		c.listenerSignatures = make(map[*wasm.FunctionType][2]*ssa.Signature, len(m.TypeSection))
    84  	}
    85  	for i := range m.TypeSection {
    86  		wasmSig := &m.TypeSection[i]
    87  		sig := SignatureForWasmFunctionType(wasmSig)
    88  		sig.ID = ssa.SignatureID(i)
    89  		c.signatures[wasmSig] = &sig
    90  		c.ssaBuilder.DeclareSignature(&sig)
    91  
    92  		if listenerOn {
    93  			beforeSig, afterSig := SignatureForListener(wasmSig)
    94  			beforeSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection))
    95  			afterSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection))*2
    96  			c.listenerSignatures[wasmSig] = [2]*ssa.Signature{beforeSig, afterSig}
    97  			c.ssaBuilder.DeclareSignature(beforeSig)
    98  			c.ssaBuilder.DeclareSignature(afterSig)
    99  		}
   100  	}
   101  
   102  	begin := ssa.SignatureID(len(m.TypeSection))
   103  	if listenerOn {
   104  		begin *= 3
   105  	}
   106  	c.memoryGrowSig = ssa.Signature{
   107  		ID: begin,
   108  		// Takes execution context and the page size to grow.
   109  		Params: []ssa.Type{ssa.TypeI64, ssa.TypeI32},
   110  		// Returns the previous page size.
   111  		Results: []ssa.Type{ssa.TypeI32},
   112  	}
   113  	c.ssaBuilder.DeclareSignature(&c.memoryGrowSig)
   114  
   115  	c.checkModuleExitCodeSig = ssa.Signature{
   116  		ID: c.memoryGrowSig.ID + 1,
   117  		// Only takes execution context.
   118  		Params: []ssa.Type{ssa.TypeI64},
   119  	}
   120  	c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig)
   121  
   122  	c.tableGrowSig = ssa.Signature{
   123  		ID:     c.checkModuleExitCodeSig.ID + 1,
   124  		Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* table index */, ssa.TypeI32 /* num */, ssa.TypeI64 /* ref */},
   125  		// Returns the previous size.
   126  		Results: []ssa.Type{ssa.TypeI32},
   127  	}
   128  	c.ssaBuilder.DeclareSignature(&c.tableGrowSig)
   129  
   130  	c.refFuncSig = ssa.Signature{
   131  		ID:     c.tableGrowSig.ID + 1,
   132  		Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* func index */},
   133  		// Returns the function reference.
   134  		Results: []ssa.Type{ssa.TypeI64},
   135  	}
   136  	c.ssaBuilder.DeclareSignature(&c.refFuncSig)
   137  
   138  	c.memmoveSig = ssa.Signature{
   139  		ID: c.refFuncSig.ID + 1,
   140  		// dst, src, and the byte count.
   141  		Params: []ssa.Type{ssa.TypeI64, ssa.TypeI64, ssa.TypeI32},
   142  	}
   143  	c.ssaBuilder.DeclareSignature(&c.memmoveSig)
   144  }
   145  
   146  // SignatureForWasmFunctionType returns the ssa.Signature for the given wasm.FunctionType.
   147  func SignatureForWasmFunctionType(typ *wasm.FunctionType) ssa.Signature {
   148  	sig := ssa.Signature{
   149  		// +2 to pass moduleContextPtr and executionContextPtr. See the inline comment LowerToSSA.
   150  		Params:  make([]ssa.Type, len(typ.Params)+2),
   151  		Results: make([]ssa.Type, len(typ.Results)),
   152  	}
   153  	sig.Params[0] = executionContextPtrTyp
   154  	sig.Params[1] = moduleContextPtrTyp
   155  	for j, typ := range typ.Params {
   156  		sig.Params[j+2] = WasmTypeToSSAType(typ)
   157  	}
   158  	for j, typ := range typ.Results {
   159  		sig.Results[j] = WasmTypeToSSAType(typ)
   160  	}
   161  	return sig
   162  }
   163  
   164  // Init initializes the state of frontendCompiler and make it ready for a next function.
   165  func (c *Compiler) Init(idx, typIndex wasm.Index, typ *wasm.FunctionType, localTypes []wasm.ValueType, body []byte, needListener bool, bodyOffsetInCodeSection uint64) {
   166  	c.ssaBuilder.Init(c.signatures[typ])
   167  	c.loweringState.reset()
   168  
   169  	c.wasmFunctionTypeIndex = typIndex
   170  	c.wasmLocalFunctionIndex = idx
   171  	c.wasmFunctionTyp = typ
   172  	c.wasmFunctionLocalTypes = localTypes
   173  	c.wasmFunctionBody = body
   174  	c.wasmFunctionBodyOffsetInCodeSection = bodyOffsetInCodeSection
   175  	c.needListener = needListener
   176  }
   177  
   178  // Note: this assumes 64-bit platform (I believe we won't have 32-bit backend ;)).
   179  const executionContextPtrTyp, moduleContextPtrTyp = ssa.TypeI64, ssa.TypeI64
   180  
   181  // LowerToSSA lowers the current function to SSA function which will be held by ssaBuilder.
   182  // After calling this, the caller will be able to access the SSA info in *Compiler.ssaBuilder.
   183  //
   184  // Note that this only does the naive lowering, and do not do any optimization, instead the caller is expected to do so.
   185  func (c *Compiler) LowerToSSA() {
   186  	builder := c.ssaBuilder
   187  
   188  	// Set up the entry block.
   189  	entryBlock := builder.AllocateBasicBlock()
   190  	builder.SetCurrentBlock(entryBlock)
   191  
   192  	// Functions always take two parameters in addition to Wasm-level parameters:
   193  	//
   194  	//  1. executionContextPtr: pointer to the *executionContext in wazevo package.
   195  	//    This will be used to exit the execution in the face of trap, plus used for host function calls.
   196  	//
   197  	// 	2. moduleContextPtr: pointer to the *moduleContextOpaque in wazevo package.
   198  	//	  This will be used to access memory, etc. Also, this will be used during host function calls.
   199  	//
   200  	// Note: it's clear that sometimes a function won't need them. For example,
   201  	//  if the function doesn't trap and doesn't make function call, then
   202  	// 	we might be able to eliminate the parameter. However, if that function
   203  	//	can be called via call_indirect, then we cannot eliminate because the
   204  	//  signature won't match with the expected one.
   205  	// TODO: maybe there's some way to do this optimization without glitches, but so far I have no clue about the feasibility.
   206  	//
   207  	// Note: In Wasmtime or many other runtimes, moduleContextPtr is called "vmContext". Also note that `moduleContextPtr`
   208  	//  is wazero-specific since other runtimes can naturally use the OS-level signal to do this job thanks to the fact that
   209  	//  they can use native stack vs wazero cannot use Go-routine stack and have to use Go-runtime allocated []byte as a stack.
   210  	c.execCtxPtrValue = entryBlock.AddParam(builder, executionContextPtrTyp)
   211  	c.moduleCtxPtrValue = entryBlock.AddParam(builder, moduleContextPtrTyp)
   212  	builder.AnnotateValue(c.execCtxPtrValue, "exec_ctx")
   213  	builder.AnnotateValue(c.moduleCtxPtrValue, "module_ctx")
   214  
   215  	for i, typ := range c.wasmFunctionTyp.Params {
   216  		st := WasmTypeToSSAType(typ)
   217  		variable := builder.DeclareVariable(st)
   218  		value := entryBlock.AddParam(builder, st)
   219  		builder.DefineVariable(variable, value, entryBlock)
   220  		c.wasmLocalToVariable[wasm.Index(i)] = variable
   221  	}
   222  	c.declareWasmLocals(entryBlock)
   223  	c.declareNecessaryVariables()
   224  
   225  	c.lowerBody(entryBlock)
   226  }
   227  
   228  // localVariable returns the SSA variable for the given Wasm local index.
   229  func (c *Compiler) localVariable(index wasm.Index) ssa.Variable {
   230  	return c.wasmLocalToVariable[index]
   231  }
   232  
   233  // declareWasmLocals declares the SSA variables for the Wasm locals.
   234  func (c *Compiler) declareWasmLocals(entry ssa.BasicBlock) {
   235  	localCount := wasm.Index(len(c.wasmFunctionTyp.Params))
   236  	for i, typ := range c.wasmFunctionLocalTypes {
   237  		st := WasmTypeToSSAType(typ)
   238  		variable := c.ssaBuilder.DeclareVariable(st)
   239  		c.wasmLocalToVariable[wasm.Index(i)+localCount] = variable
   240  
   241  		zeroInst := c.ssaBuilder.AllocateInstruction()
   242  		switch st {
   243  		case ssa.TypeI32:
   244  			zeroInst.AsIconst32(0)
   245  		case ssa.TypeI64:
   246  			zeroInst.AsIconst64(0)
   247  		case ssa.TypeF32:
   248  			zeroInst.AsF32const(0)
   249  		case ssa.TypeF64:
   250  			zeroInst.AsF64const(0)
   251  		case ssa.TypeV128:
   252  			zeroInst.AsVconst(0, 0)
   253  		default:
   254  			panic("TODO: " + wasm.ValueTypeName(typ))
   255  		}
   256  
   257  		c.ssaBuilder.InsertInstruction(zeroInst)
   258  		value := zeroInst.Return()
   259  		c.ssaBuilder.DefineVariable(variable, value, entry)
   260  	}
   261  }
   262  
   263  func (c *Compiler) declareNecessaryVariables() {
   264  	c.needMemory = c.m.ImportMemoryCount > 0 || c.m.MemorySection != nil
   265  	if c.needMemory {
   266  		c.memoryBaseVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64)
   267  		c.memoryLenVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64)
   268  	}
   269  
   270  	c.globalVariables = c.globalVariables[:0]
   271  	c.mutableGlobalVariablesIndexes = c.mutableGlobalVariablesIndexes[:0]
   272  	c.globalVariablesTypes = c.globalVariablesTypes[:0]
   273  	for _, imp := range c.m.ImportSection {
   274  		if imp.Type == wasm.ExternTypeGlobal {
   275  			desc := imp.DescGlobal
   276  			c.declareWasmGlobal(desc.ValType, desc.Mutable)
   277  		}
   278  	}
   279  	for _, g := range c.m.GlobalSection {
   280  		desc := g.Type
   281  		c.declareWasmGlobal(desc.ValType, desc.Mutable)
   282  	}
   283  
   284  	// TODO: add tables.
   285  }
   286  
   287  func (c *Compiler) declareWasmGlobal(typ wasm.ValueType, mutable bool) {
   288  	var st ssa.Type
   289  	switch typ {
   290  	case wasm.ValueTypeI32:
   291  		st = ssa.TypeI32
   292  	case wasm.ValueTypeI64,
   293  		// Both externref and funcref are represented as I64 since we only support 64-bit platforms.
   294  		wasm.ValueTypeExternref, wasm.ValueTypeFuncref:
   295  		st = ssa.TypeI64
   296  	case wasm.ValueTypeF32:
   297  		st = ssa.TypeF32
   298  	case wasm.ValueTypeF64:
   299  		st = ssa.TypeF64
   300  	case wasm.ValueTypeV128:
   301  		st = ssa.TypeV128
   302  	default:
   303  		panic("TODO: " + wasm.ValueTypeName(typ))
   304  	}
   305  	v := c.ssaBuilder.DeclareVariable(st)
   306  	index := wasm.Index(len(c.globalVariables))
   307  	c.globalVariables = append(c.globalVariables, v)
   308  	c.globalVariablesTypes = append(c.globalVariablesTypes, st)
   309  	if mutable {
   310  		c.mutableGlobalVariablesIndexes = append(c.mutableGlobalVariablesIndexes, index)
   311  	}
   312  }
   313  
   314  // WasmTypeToSSAType converts wasm.ValueType to ssa.Type.
   315  func WasmTypeToSSAType(vt wasm.ValueType) ssa.Type {
   316  	switch vt {
   317  	case wasm.ValueTypeI32:
   318  		return ssa.TypeI32
   319  	case wasm.ValueTypeI64,
   320  		// Both externref and funcref are represented as I64 since we only support 64-bit platforms.
   321  		wasm.ValueTypeExternref, wasm.ValueTypeFuncref:
   322  		return ssa.TypeI64
   323  	case wasm.ValueTypeF32:
   324  		return ssa.TypeF32
   325  	case wasm.ValueTypeF64:
   326  		return ssa.TypeF64
   327  	case wasm.ValueTypeV128:
   328  		return ssa.TypeV128
   329  	default:
   330  		panic("TODO: " + wasm.ValueTypeName(vt))
   331  	}
   332  }
   333  
   334  // addBlockParamsFromWasmTypes adds the block parameters to the given block.
   335  func (c *Compiler) addBlockParamsFromWasmTypes(tps []wasm.ValueType, blk ssa.BasicBlock) {
   336  	for _, typ := range tps {
   337  		st := WasmTypeToSSAType(typ)
   338  		blk.AddParam(c.ssaBuilder, st)
   339  	}
   340  }
   341  
   342  // formatBuilder outputs the constructed SSA function as a string with a source information.
   343  func (c *Compiler) formatBuilder() string {
   344  	return c.ssaBuilder.Format()
   345  }
   346  
   347  // SignatureForListener returns the signatures for the listener functions.
   348  func SignatureForListener(wasmSig *wasm.FunctionType) (*ssa.Signature, *ssa.Signature) {
   349  	beforeSig := &ssa.Signature{}
   350  	beforeSig.Params = make([]ssa.Type, len(wasmSig.Params)+2)
   351  	beforeSig.Params[0] = ssa.TypeI64 // Execution context.
   352  	beforeSig.Params[1] = ssa.TypeI32 // Function index.
   353  	for i, p := range wasmSig.Params {
   354  		beforeSig.Params[i+2] = WasmTypeToSSAType(p)
   355  	}
   356  	afterSig := &ssa.Signature{}
   357  	afterSig.Params = make([]ssa.Type, len(wasmSig.Results)+2)
   358  	afterSig.Params[0] = ssa.TypeI64 // Execution context.
   359  	afterSig.Params[1] = ssa.TypeI32 // Function index.
   360  	for i, p := range wasmSig.Results {
   361  		afterSig.Params[i+2] = WasmTypeToSSAType(p)
   362  	}
   363  	return beforeSig, afterSig
   364  }
   365  
   366  // isBoundSafe returns true if the given value is known to be safe to access up to the given bound.
   367  func (c *Compiler) getKnownSafeBound(v ssa.ValueID) *knownSafeBound {
   368  	if int(v) >= len(c.knownSafeBounds) {
   369  		return nil
   370  	}
   371  	return &c.knownSafeBounds[v]
   372  }
   373  
   374  // recordKnownSafeBound records the given safe bound for the given value.
   375  func (c *Compiler) recordKnownSafeBound(v ssa.ValueID, safeBound uint64, absoluteAddr ssa.Value) {
   376  	if int(v) >= len(c.knownSafeBounds) {
   377  		c.knownSafeBounds = append(c.knownSafeBounds, make([]knownSafeBound, v+1)...)
   378  	}
   379  
   380  	if exiting := c.knownSafeBounds[v]; exiting.bound == 0 {
   381  		c.knownSafeBounds[v] = knownSafeBound{
   382  			bound:        safeBound,
   383  			absoluteAddr: absoluteAddr,
   384  		}
   385  		c.knownSafeBoundsSet = append(c.knownSafeBoundsSet, v)
   386  	} else if safeBound > exiting.bound {
   387  		c.knownSafeBounds[v].bound = safeBound
   388  	}
   389  }
   390  
   391  // clearSafeBounds clears the known safe bounds. This must be called
   392  // after the compilation of each block.
   393  func (c *Compiler) clearSafeBounds() {
   394  	for _, v := range c.knownSafeBoundsSet {
   395  		ptr := &c.knownSafeBounds[v]
   396  		ptr.bound = 0
   397  	}
   398  	c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0]
   399  }
   400  
   401  func (k *knownSafeBound) valid() bool {
   402  	return k != nil && k.bound > 0
   403  }