github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/wazevo/frontend/frontend.go (about)

     1  // Package frontend implements the translation of WebAssembly to SSA IR using the ssa package.
     2  package frontend
     3  
     4  import (
     5  	"bytes"
     6  
     7  	"github.com/wasilibs/wazerox/internal/engine/wazevo/ssa"
     8  	"github.com/wasilibs/wazerox/internal/engine/wazevo/wazevoapi"
     9  	"github.com/wasilibs/wazerox/internal/wasm"
    10  )
    11  
    12  // Compiler is in charge of lowering Wasm to SSA IR, and does the optimization
    13  // on top of it in architecture-independent way.
    14  type Compiler struct {
    15  	// Per-module data that is used across all functions.
    16  
    17  	m      *wasm.Module
    18  	offset *wazevoapi.ModuleContextOffsetData
    19  	// ssaBuilder is a ssa.Builder used by this frontend.
    20  	ssaBuilder             ssa.Builder
    21  	signatures             map[*wasm.FunctionType]*ssa.Signature
    22  	listenerSignatures     map[*wasm.FunctionType][2]*ssa.Signature
    23  	memoryGrowSig          ssa.Signature
    24  	checkModuleExitCodeSig ssa.Signature
    25  	tableGrowSig           ssa.Signature
    26  	refFuncSig             ssa.Signature
    27  	memmoveSig             ssa.Signature
    28  	checkModuleExitCodeArg [1]ssa.Value
    29  	ensureTermination      bool
    30  
    31  	// Followings are reset by per function.
    32  
    33  	// wasmLocalToVariable maps the index (considered as wasm.Index of locals)
    34  	// to the corresponding ssa.Variable.
    35  	wasmLocalToVariable                   map[wasm.Index]ssa.Variable
    36  	wasmLocalFunctionIndex                wasm.Index
    37  	wasmFunctionTypeIndex                 wasm.Index
    38  	wasmFunctionTyp                       *wasm.FunctionType
    39  	wasmFunctionLocalTypes                []wasm.ValueType
    40  	wasmFunctionBody                      []byte
    41  	wasmFunctionBodyOffsetInCodeSection   uint64
    42  	memoryBaseVariable, memoryLenVariable ssa.Variable
    43  	needMemory                            bool
    44  	globalVariables                       []ssa.Variable
    45  	globalVariablesTypes                  []ssa.Type
    46  	mutableGlobalVariablesIndexes         []wasm.Index // index to ^.
    47  	needListener                          bool
    48  	needSourceOffsetInfo                  bool
    49  	// br is reused during lowering.
    50  	br            *bytes.Reader
    51  	loweringState loweringState
    52  
    53  	execCtxPtrValue, moduleCtxPtrValue ssa.Value
    54  }
    55  
    56  // NewFrontendCompiler returns a frontend Compiler.
    57  func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler {
    58  	c := &Compiler{
    59  		m:                    m,
    60  		ssaBuilder:           ssaBuilder,
    61  		br:                   bytes.NewReader(nil),
    62  		wasmLocalToVariable:  make(map[wasm.Index]ssa.Variable),
    63  		offset:               offset,
    64  		ensureTermination:    ensureTermination,
    65  		needSourceOffsetInfo: sourceInfo,
    66  	}
    67  	c.declareSignatures(listenerOn)
    68  	return c
    69  }
    70  
    71  func (c *Compiler) declareSignatures(listenerOn bool) {
    72  	m := c.m
    73  	c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+2)
    74  	if listenerOn {
    75  		c.listenerSignatures = make(map[*wasm.FunctionType][2]*ssa.Signature, len(m.TypeSection))
    76  	}
    77  	for i := range m.TypeSection {
    78  		wasmSig := &m.TypeSection[i]
    79  		sig := SignatureForWasmFunctionType(wasmSig)
    80  		sig.ID = ssa.SignatureID(i)
    81  		c.signatures[wasmSig] = &sig
    82  		c.ssaBuilder.DeclareSignature(&sig)
    83  
    84  		if listenerOn {
    85  			beforeSig, afterSig := SignatureForListener(wasmSig)
    86  			beforeSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection))
    87  			afterSig.ID = ssa.SignatureID(i) + ssa.SignatureID(len(m.TypeSection))*2
    88  			c.listenerSignatures[wasmSig] = [2]*ssa.Signature{beforeSig, afterSig}
    89  			c.ssaBuilder.DeclareSignature(beforeSig)
    90  			c.ssaBuilder.DeclareSignature(afterSig)
    91  		}
    92  	}
    93  
    94  	begin := ssa.SignatureID(len(m.TypeSection))
    95  	if listenerOn {
    96  		begin *= 3
    97  	}
    98  	c.memoryGrowSig = ssa.Signature{
    99  		ID: begin,
   100  		// Takes execution context and the page size to grow.
   101  		Params: []ssa.Type{ssa.TypeI64, ssa.TypeI32},
   102  		// Returns the previous page size.
   103  		Results: []ssa.Type{ssa.TypeI32},
   104  	}
   105  	c.ssaBuilder.DeclareSignature(&c.memoryGrowSig)
   106  
   107  	c.checkModuleExitCodeSig = ssa.Signature{
   108  		ID: c.memoryGrowSig.ID + 1,
   109  		// Only takes execution context.
   110  		Params: []ssa.Type{ssa.TypeI64},
   111  	}
   112  	c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig)
   113  
   114  	c.tableGrowSig = ssa.Signature{
   115  		ID:     c.checkModuleExitCodeSig.ID + 1,
   116  		Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* table index */, ssa.TypeI32 /* num */, ssa.TypeI64 /* ref */},
   117  		// Returns the previous size.
   118  		Results: []ssa.Type{ssa.TypeI32},
   119  	}
   120  	c.ssaBuilder.DeclareSignature(&c.tableGrowSig)
   121  
   122  	c.refFuncSig = ssa.Signature{
   123  		ID:     c.tableGrowSig.ID + 1,
   124  		Params: []ssa.Type{ssa.TypeI64 /* exec context */, ssa.TypeI32 /* func index */},
   125  		// Returns the function reference.
   126  		Results: []ssa.Type{ssa.TypeI64},
   127  	}
   128  	c.ssaBuilder.DeclareSignature(&c.refFuncSig)
   129  
   130  	c.memmoveSig = ssa.Signature{
   131  		ID: c.refFuncSig.ID + 1,
   132  		// dst, src, and the byte count.
   133  		Params: []ssa.Type{ssa.TypeI64, ssa.TypeI64, ssa.TypeI32},
   134  	}
   135  	c.ssaBuilder.DeclareSignature(&c.memmoveSig)
   136  }
   137  
   138  // SignatureForWasmFunctionType returns the ssa.Signature for the given wasm.FunctionType.
   139  func SignatureForWasmFunctionType(typ *wasm.FunctionType) ssa.Signature {
   140  	sig := ssa.Signature{
   141  		// +2 to pass moduleContextPtr and executionContextPtr. See the inline comment LowerToSSA.
   142  		Params:  make([]ssa.Type, len(typ.Params)+2),
   143  		Results: make([]ssa.Type, len(typ.Results)),
   144  	}
   145  	sig.Params[0] = executionContextPtrTyp
   146  	sig.Params[1] = moduleContextPtrTyp
   147  	for j, typ := range typ.Params {
   148  		sig.Params[j+2] = WasmTypeToSSAType(typ)
   149  	}
   150  	for j, typ := range typ.Results {
   151  		sig.Results[j] = WasmTypeToSSAType(typ)
   152  	}
   153  	return sig
   154  }
   155  
   156  // Init initializes the state of frontendCompiler and make it ready for a next function.
   157  func (c *Compiler) Init(idx, typIndex wasm.Index, typ *wasm.FunctionType, localTypes []wasm.ValueType, body []byte, needListener bool, bodyOffsetInCodeSection uint64) {
   158  	c.ssaBuilder.Init(c.signatures[typ])
   159  	c.loweringState.reset()
   160  
   161  	c.wasmFunctionTypeIndex = typIndex
   162  	c.wasmLocalFunctionIndex = idx
   163  	c.wasmFunctionTyp = typ
   164  	c.wasmFunctionLocalTypes = localTypes
   165  	c.wasmFunctionBody = body
   166  	c.wasmFunctionBodyOffsetInCodeSection = bodyOffsetInCodeSection
   167  	c.needListener = needListener
   168  }
   169  
   170  // Note: this assumes 64-bit platform (I believe we won't have 32-bit backend ;)).
   171  const executionContextPtrTyp, moduleContextPtrTyp = ssa.TypeI64, ssa.TypeI64
   172  
   173  // LowerToSSA lowers the current function to SSA function which will be held by ssaBuilder.
   174  // After calling this, the caller will be able to access the SSA info in *Compiler.ssaBuilder.
   175  //
   176  // Note that this only does the naive lowering, and do not do any optimization, instead the caller is expected to do so.
   177  func (c *Compiler) LowerToSSA() {
   178  	builder := c.ssaBuilder
   179  
   180  	// Set up the entry block.
   181  	entryBlock := builder.AllocateBasicBlock()
   182  	builder.SetCurrentBlock(entryBlock)
   183  
   184  	// Functions always take two parameters in addition to Wasm-level parameters:
   185  	//
   186  	//  1. executionContextPtr: pointer to the *executionContext in wazevo package.
   187  	//    This will be used to exit the execution in the face of trap, plus used for host function calls.
   188  	//
   189  	// 	2. moduleContextPtr: pointer to the *moduleContextOpaque in wazevo package.
   190  	//	  This will be used to access memory, etc. Also, this will be used during host function calls.
   191  	//
   192  	// Note: it's clear that sometimes a function won't need them. For example,
   193  	//  if the function doesn't trap and doesn't make function call, then
   194  	// 	we might be able to eliminate the parameter. However, if that function
   195  	//	can be called via call_indirect, then we cannot eliminate because the
   196  	//  signature won't match with the expected one.
   197  	// TODO: maybe there's some way to do this optimization without glitches, but so far I have no clue about the feasibility.
   198  	//
   199  	// Note: In Wasmtime or many other runtimes, moduleContextPtr is called "vmContext". Also note that `moduleContextPtr`
   200  	//  is wazero-specific since other runtimes can naturally use the OS-level signal to do this job thanks to the fact that
   201  	//  they can use native stack vs wazero cannot use Go-routine stack and have to use Go-runtime allocated []byte as a stack.
   202  	c.execCtxPtrValue = entryBlock.AddParam(builder, executionContextPtrTyp)
   203  	c.moduleCtxPtrValue = entryBlock.AddParam(builder, moduleContextPtrTyp)
   204  	builder.AnnotateValue(c.execCtxPtrValue, "exec_ctx")
   205  	builder.AnnotateValue(c.moduleCtxPtrValue, "module_ctx")
   206  
   207  	for i, typ := range c.wasmFunctionTyp.Params {
   208  		st := WasmTypeToSSAType(typ)
   209  		variable := builder.DeclareVariable(st)
   210  		value := entryBlock.AddParam(builder, st)
   211  		builder.DefineVariable(variable, value, entryBlock)
   212  		c.wasmLocalToVariable[wasm.Index(i)] = variable
   213  	}
   214  	c.declareWasmLocals(entryBlock)
   215  	c.declareNecessaryVariables()
   216  
   217  	c.lowerBody(entryBlock)
   218  }
   219  
   220  // localVariable returns the SSA variable for the given Wasm local index.
   221  func (c *Compiler) localVariable(index wasm.Index) ssa.Variable {
   222  	return c.wasmLocalToVariable[index]
   223  }
   224  
   225  // declareWasmLocals declares the SSA variables for the Wasm locals.
   226  func (c *Compiler) declareWasmLocals(entry ssa.BasicBlock) {
   227  	localCount := wasm.Index(len(c.wasmFunctionTyp.Params))
   228  	for i, typ := range c.wasmFunctionLocalTypes {
   229  		st := WasmTypeToSSAType(typ)
   230  		variable := c.ssaBuilder.DeclareVariable(st)
   231  		c.wasmLocalToVariable[wasm.Index(i)+localCount] = variable
   232  
   233  		zeroInst := c.ssaBuilder.AllocateInstruction()
   234  		switch st {
   235  		case ssa.TypeI32:
   236  			zeroInst.AsIconst32(0)
   237  		case ssa.TypeI64:
   238  			zeroInst.AsIconst64(0)
   239  		case ssa.TypeF32:
   240  			zeroInst.AsF32const(0)
   241  		case ssa.TypeF64:
   242  			zeroInst.AsF64const(0)
   243  		case ssa.TypeV128:
   244  			zeroInst.AsVconst(0, 0)
   245  		default:
   246  			panic("TODO: " + wasm.ValueTypeName(typ))
   247  		}
   248  
   249  		c.ssaBuilder.InsertInstruction(zeroInst)
   250  		value := zeroInst.Return()
   251  		c.ssaBuilder.DefineVariable(variable, value, entry)
   252  	}
   253  }
   254  
   255  func (c *Compiler) declareNecessaryVariables() {
   256  	c.needMemory = c.m.ImportMemoryCount > 0 || c.m.MemorySection != nil
   257  	if c.needMemory {
   258  		c.memoryBaseVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64)
   259  		c.memoryLenVariable = c.ssaBuilder.DeclareVariable(ssa.TypeI64)
   260  	}
   261  
   262  	c.globalVariables = c.globalVariables[:0]
   263  	c.mutableGlobalVariablesIndexes = c.mutableGlobalVariablesIndexes[:0]
   264  	c.globalVariablesTypes = c.globalVariablesTypes[:0]
   265  	for _, imp := range c.m.ImportSection {
   266  		if imp.Type == wasm.ExternTypeGlobal {
   267  			desc := imp.DescGlobal
   268  			c.declareWasmGlobal(desc.ValType, desc.Mutable)
   269  		}
   270  	}
   271  	for _, g := range c.m.GlobalSection {
   272  		desc := g.Type
   273  		c.declareWasmGlobal(desc.ValType, desc.Mutable)
   274  	}
   275  
   276  	// TODO: add tables.
   277  }
   278  
   279  func (c *Compiler) declareWasmGlobal(typ wasm.ValueType, mutable bool) {
   280  	var st ssa.Type
   281  	switch typ {
   282  	case wasm.ValueTypeI32:
   283  		st = ssa.TypeI32
   284  	case wasm.ValueTypeI64,
   285  		// Both externref and funcref are represented as I64 since we only support 64-bit platforms.
   286  		wasm.ValueTypeExternref, wasm.ValueTypeFuncref:
   287  		st = ssa.TypeI64
   288  	case wasm.ValueTypeF32:
   289  		st = ssa.TypeF32
   290  	case wasm.ValueTypeF64:
   291  		st = ssa.TypeF64
   292  	case wasm.ValueTypeV128:
   293  		st = ssa.TypeV128
   294  	default:
   295  		panic("TODO: " + wasm.ValueTypeName(typ))
   296  	}
   297  	v := c.ssaBuilder.DeclareVariable(st)
   298  	index := wasm.Index(len(c.globalVariables))
   299  	c.globalVariables = append(c.globalVariables, v)
   300  	c.globalVariablesTypes = append(c.globalVariablesTypes, st)
   301  	if mutable {
   302  		c.mutableGlobalVariablesIndexes = append(c.mutableGlobalVariablesIndexes, index)
   303  	}
   304  }
   305  
   306  // WasmTypeToSSAType converts wasm.ValueType to ssa.Type.
   307  func WasmTypeToSSAType(vt wasm.ValueType) ssa.Type {
   308  	switch vt {
   309  	case wasm.ValueTypeI32:
   310  		return ssa.TypeI32
   311  	case wasm.ValueTypeI64,
   312  		// Both externref and funcref are represented as I64 since we only support 64-bit platforms.
   313  		wasm.ValueTypeExternref, wasm.ValueTypeFuncref:
   314  		return ssa.TypeI64
   315  	case wasm.ValueTypeF32:
   316  		return ssa.TypeF32
   317  	case wasm.ValueTypeF64:
   318  		return ssa.TypeF64
   319  	case wasm.ValueTypeV128:
   320  		return ssa.TypeV128
   321  	default:
   322  		panic("TODO: " + wasm.ValueTypeName(vt))
   323  	}
   324  }
   325  
   326  // addBlockParamsFromWasmTypes adds the block parameters to the given block.
   327  func (c *Compiler) addBlockParamsFromWasmTypes(tps []wasm.ValueType, blk ssa.BasicBlock) {
   328  	for _, typ := range tps {
   329  		st := WasmTypeToSSAType(typ)
   330  		blk.AddParam(c.ssaBuilder, st)
   331  	}
   332  }
   333  
   334  // formatBuilder outputs the constructed SSA function as a string with a source information.
   335  func (c *Compiler) formatBuilder() string {
   336  	return c.ssaBuilder.Format()
   337  }
   338  
   339  // SignatureForListener returns the signatures for the listener functions.
   340  func SignatureForListener(wasmSig *wasm.FunctionType) (*ssa.Signature, *ssa.Signature) {
   341  	beforeSig := &ssa.Signature{}
   342  	beforeSig.Params = make([]ssa.Type, len(wasmSig.Params)+2)
   343  	beforeSig.Params[0] = ssa.TypeI64 // Execution context.
   344  	beforeSig.Params[1] = ssa.TypeI32 // Function index.
   345  	for i, p := range wasmSig.Params {
   346  		beforeSig.Params[i+2] = WasmTypeToSSAType(p)
   347  	}
   348  	afterSig := &ssa.Signature{}
   349  	afterSig.Params = make([]ssa.Type, len(wasmSig.Results)+2)
   350  	afterSig.Params[0] = ssa.TypeI64 // Execution context.
   351  	afterSig.Params[1] = ssa.TypeI32 // Function index.
   352  	for i, p := range wasmSig.Results {
   353  		afterSig.Params[i+2] = WasmTypeToSSAType(p)
   354  	}
   355  	return beforeSig, afterSig
   356  }