github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/backend/compiler.go (about)

     1  package backend
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc"
     8  	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
     9  	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
    10  )
    11  
    12  // NewCompiler returns a new Compiler that can generate a machine code.
    13  func NewCompiler(ctx context.Context, mach Machine, builder ssa.Builder) Compiler {
    14  	return newCompiler(ctx, mach, builder)
    15  }
    16  
    17  func newCompiler(_ context.Context, mach Machine, builder ssa.Builder) *compiler {
    18  	argResultInts, argResultFloats := mach.ArgsResultsRegs()
    19  	c := &compiler{
    20  		mach: mach, ssaBuilder: builder,
    21  		nextVRegID:      regalloc.VRegIDNonReservedBegin,
    22  		argResultInts:   argResultInts,
    23  		argResultFloats: argResultFloats,
    24  	}
    25  	mach.SetCompiler(c)
    26  	return c
    27  }
    28  
    29  // Compiler is the backend of wazevo which takes ssa.Builder and Machine,
    30  // use the information there to emit the final machine code.
    31  type Compiler interface {
    32  	// SSABuilder returns the ssa.Builder used by this compiler.
    33  	SSABuilder() ssa.Builder
    34  
    35  	// Compile executes the following steps:
    36  	// 	1. Lower()
    37  	// 	2. RegAlloc()
    38  	// 	3. Finalize()
    39  	// 	4. Encode()
    40  	//
    41  	// Each step can be called individually for testing purpose, therefore they are exposed in this interface too.
    42  	//
    43  	// The returned byte slices are the machine code and the relocation information for the machine code.
    44  	// The caller is responsible for copying them immediately since the compiler may reuse the buffer.
    45  	Compile(ctx context.Context) (_ []byte, _ []RelocationInfo, _ error)
    46  
    47  	// Lower lowers the given ssa.Instruction to the machine-specific instructions.
    48  	Lower()
    49  
    50  	// RegAlloc performs the register allocation after Lower is called.
    51  	RegAlloc()
    52  
    53  	// Finalize performs the finalization of the compilation, including machine code emission.
    54  	// This must be called after RegAlloc.
    55  	Finalize(ctx context.Context) error
    56  
    57  	// Buf returns the buffer of the encoded machine code. This is only used for testing purpose.
    58  	Buf() []byte
    59  
    60  	BufPtr() *[]byte
    61  
    62  	// Format returns the debug string of the current state of the compiler.
    63  	Format() string
    64  
    65  	// Init initializes the internal state of the compiler for the next compilation.
    66  	Init()
    67  
    68  	// AllocateVReg allocates a new virtual register of the given type.
    69  	AllocateVReg(typ ssa.Type) regalloc.VReg
    70  
    71  	// ValueDefinition returns the definition of the given value.
    72  	ValueDefinition(ssa.Value) *SSAValueDefinition
    73  
    74  	// VRegOf returns the virtual register of the given ssa.Value.
    75  	VRegOf(value ssa.Value) regalloc.VReg
    76  
    77  	// TypeOf returns the ssa.Type of the given virtual register.
    78  	TypeOf(regalloc.VReg) ssa.Type
    79  
    80  	// MatchInstr returns true if the given definition is from an instruction with the given opcode, the current group ID,
    81  	// and a refcount of 1. That means, the instruction can be merged/swapped within the current instruction group.
    82  	MatchInstr(def *SSAValueDefinition, opcode ssa.Opcode) bool
    83  
    84  	// MatchInstrOneOf is the same as MatchInstr but for multiple opcodes. If it matches one of ssa.Opcode,
    85  	// this returns the opcode. Otherwise, this returns ssa.OpcodeInvalid.
    86  	//
    87  	// Note: caller should be careful to avoid excessive allocation on opcodes slice.
    88  	MatchInstrOneOf(def *SSAValueDefinition, opcodes []ssa.Opcode) ssa.Opcode
    89  
    90  	// AddRelocationInfo appends the relocation information for the function reference at the current buffer offset.
    91  	AddRelocationInfo(funcRef ssa.FuncRef)
    92  
    93  	// AddSourceOffsetInfo appends the source offset information for the given offset.
    94  	AddSourceOffsetInfo(executableOffset int64, sourceOffset ssa.SourceOffset)
    95  
    96  	// SourceOffsetInfo returns the source offset information for the current buffer offset.
    97  	SourceOffsetInfo() []SourceOffsetInfo
    98  
    99  	// EmitByte appends a byte to the buffer. Used during the code emission.
   100  	EmitByte(b byte)
   101  
   102  	// Emit4Bytes appends 4 bytes to the buffer. Used during the code emission.
   103  	Emit4Bytes(b uint32)
   104  
   105  	// Emit8Bytes appends 8 bytes to the buffer. Used during the code emission.
   106  	Emit8Bytes(b uint64)
   107  
   108  	// GetFunctionABI returns the ABI information for the given signature.
   109  	GetFunctionABI(sig *ssa.Signature) *FunctionABI
   110  }
   111  
   112  // RelocationInfo represents the relocation information for a call instruction.
   113  type RelocationInfo struct {
   114  	// Offset represents the offset from the beginning of the machine code of either a function or the entire module.
   115  	Offset int64
   116  	// Target is the target function of the call instruction.
   117  	FuncRef ssa.FuncRef
   118  }
   119  
   120  // compiler implements Compiler.
   121  type compiler struct {
   122  	mach       Machine
   123  	currentGID ssa.InstructionGroupID
   124  	ssaBuilder ssa.Builder
   125  	// nextVRegID is the next virtual register ID to be allocated.
   126  	nextVRegID regalloc.VRegID
   127  	// ssaValueToVRegs maps ssa.ValueID to regalloc.VReg.
   128  	ssaValueToVRegs [] /* VRegID to */ regalloc.VReg
   129  	// ssaValueDefinitions maps ssa.ValueID to its definition.
   130  	ssaValueDefinitions []SSAValueDefinition
   131  	// ssaValueRefCounts is a cached list obtained by ssa.Builder.ValueRefCounts().
   132  	ssaValueRefCounts []int
   133  	// returnVRegs is the list of virtual registers that store the return values.
   134  	returnVRegs  []regalloc.VReg
   135  	varEdges     [][2]regalloc.VReg
   136  	varEdgeTypes []ssa.Type
   137  	constEdges   []struct {
   138  		cInst *ssa.Instruction
   139  		dst   regalloc.VReg
   140  	}
   141  	vRegSet         []bool
   142  	vRegIDs         []regalloc.VRegID
   143  	tempRegs        []regalloc.VReg
   144  	tmpVals         []ssa.Value
   145  	ssaTypeOfVRegID [] /* VRegID to */ ssa.Type
   146  	buf             []byte
   147  	relocations     []RelocationInfo
   148  	sourceOffsets   []SourceOffsetInfo
   149  	// abis maps ssa.SignatureID to the ABI implementation.
   150  	abis                           []FunctionABI
   151  	argResultInts, argResultFloats []regalloc.RealReg
   152  }
   153  
   154  // SourceOffsetInfo is a data to associate the source offset with the executable offset.
   155  type SourceOffsetInfo struct {
   156  	// SourceOffset is the source offset in the original source code.
   157  	SourceOffset ssa.SourceOffset
   158  	// ExecutableOffset is the offset in the compiled executable.
   159  	ExecutableOffset int64
   160  }
   161  
   162  // Compile implements Compiler.Compile.
   163  func (c *compiler) Compile(ctx context.Context) ([]byte, []RelocationInfo, error) {
   164  	c.Lower()
   165  	if wazevoapi.PrintSSAToBackendIRLowering && wazevoapi.PrintEnabledIndex(ctx) {
   166  		fmt.Printf("[[[after lowering for %s ]]]%s\n", wazevoapi.GetCurrentFunctionName(ctx), c.Format())
   167  	}
   168  	if wazevoapi.DeterministicCompilationVerifierEnabled {
   169  		wazevoapi.VerifyOrSetDeterministicCompilationContextValue(ctx, "After lowering to ISA specific IR", c.Format())
   170  	}
   171  	c.RegAlloc()
   172  	if wazevoapi.PrintRegisterAllocated && wazevoapi.PrintEnabledIndex(ctx) {
   173  		fmt.Printf("[[[after regalloc for %s]]]%s\n", wazevoapi.GetCurrentFunctionName(ctx), c.Format())
   174  	}
   175  	if wazevoapi.DeterministicCompilationVerifierEnabled {
   176  		wazevoapi.VerifyOrSetDeterministicCompilationContextValue(ctx, "After Register Allocation", c.Format())
   177  	}
   178  	if err := c.Finalize(ctx); err != nil {
   179  		return nil, nil, err
   180  	}
   181  	if wazevoapi.PrintFinalizedMachineCode && wazevoapi.PrintEnabledIndex(ctx) {
   182  		fmt.Printf("[[[after finalize for %s]]]%s\n", wazevoapi.GetCurrentFunctionName(ctx), c.Format())
   183  	}
   184  	if wazevoapi.DeterministicCompilationVerifierEnabled {
   185  		wazevoapi.VerifyOrSetDeterministicCompilationContextValue(ctx, "After Finalization", c.Format())
   186  	}
   187  	return c.buf, c.relocations, nil
   188  }
   189  
   190  // RegAlloc implements Compiler.RegAlloc.
   191  func (c *compiler) RegAlloc() {
   192  	c.mach.RegAlloc()
   193  }
   194  
   195  // Finalize implements Compiler.Finalize.
   196  func (c *compiler) Finalize(ctx context.Context) error {
   197  	c.mach.PostRegAlloc()
   198  	return c.mach.Encode(ctx)
   199  }
   200  
   201  // setCurrentGroupID sets the current instruction group ID.
   202  func (c *compiler) setCurrentGroupID(gid ssa.InstructionGroupID) {
   203  	c.currentGID = gid
   204  }
   205  
   206  // assignVirtualRegisters assigns a virtual register to each ssa.ValueID Valid in the ssa.Builder.
   207  func (c *compiler) assignVirtualRegisters() {
   208  	builder := c.ssaBuilder
   209  	refCounts := builder.ValueRefCounts()
   210  	c.ssaValueRefCounts = refCounts
   211  
   212  	need := len(refCounts)
   213  	if need >= len(c.ssaValueToVRegs) {
   214  		c.ssaValueToVRegs = append(c.ssaValueToVRegs, make([]regalloc.VReg, need+1)...)
   215  	}
   216  	if need >= len(c.ssaValueDefinitions) {
   217  		c.ssaValueDefinitions = append(c.ssaValueDefinitions, make([]SSAValueDefinition, need+1)...)
   218  	}
   219  
   220  	for blk := builder.BlockIteratorReversePostOrderBegin(); blk != nil; blk = builder.BlockIteratorReversePostOrderNext() {
   221  		// First we assign a virtual register to each parameter.
   222  		for i := 0; i < blk.Params(); i++ {
   223  			p := blk.Param(i)
   224  			pid := p.ID()
   225  			typ := p.Type()
   226  			vreg := c.AllocateVReg(typ)
   227  			c.ssaValueToVRegs[pid] = vreg
   228  			c.ssaValueDefinitions[pid] = SSAValueDefinition{BlockParamValue: p, BlkParamVReg: vreg}
   229  			c.ssaTypeOfVRegID[vreg.ID()] = p.Type()
   230  		}
   231  
   232  		// Assigns each value to a virtual register produced by instructions.
   233  		for cur := blk.Root(); cur != nil; cur = cur.Next() {
   234  			r, rs := cur.Returns()
   235  			var N int
   236  			if r.Valid() {
   237  				id := r.ID()
   238  				ssaTyp := r.Type()
   239  				typ := r.Type()
   240  				vReg := c.AllocateVReg(typ)
   241  				c.ssaValueToVRegs[id] = vReg
   242  				c.ssaValueDefinitions[id] = SSAValueDefinition{
   243  					Instr:    cur,
   244  					N:        0,
   245  					RefCount: refCounts[id],
   246  				}
   247  				c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp
   248  				N++
   249  			}
   250  			for _, r := range rs {
   251  				id := r.ID()
   252  				ssaTyp := r.Type()
   253  				vReg := c.AllocateVReg(ssaTyp)
   254  				c.ssaValueToVRegs[id] = vReg
   255  				c.ssaValueDefinitions[id] = SSAValueDefinition{
   256  					Instr:    cur,
   257  					N:        N,
   258  					RefCount: refCounts[id],
   259  				}
   260  				c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp
   261  				N++
   262  			}
   263  		}
   264  	}
   265  
   266  	for i, retBlk := 0, builder.ReturnBlock(); i < retBlk.Params(); i++ {
   267  		typ := retBlk.Param(i).Type()
   268  		vReg := c.AllocateVReg(typ)
   269  		c.returnVRegs = append(c.returnVRegs, vReg)
   270  		c.ssaTypeOfVRegID[vReg.ID()] = typ
   271  	}
   272  }
   273  
   274  // AllocateVReg implements Compiler.AllocateVReg.
   275  func (c *compiler) AllocateVReg(typ ssa.Type) regalloc.VReg {
   276  	regType := regalloc.RegTypeOf(typ)
   277  	r := regalloc.VReg(c.nextVRegID).SetRegType(regType)
   278  
   279  	id := r.ID()
   280  	if int(id) >= len(c.ssaTypeOfVRegID) {
   281  		c.ssaTypeOfVRegID = append(c.ssaTypeOfVRegID, make([]ssa.Type, id+1)...)
   282  	}
   283  	c.ssaTypeOfVRegID[id] = typ
   284  	c.nextVRegID++
   285  	return r
   286  }
   287  
   288  // Init implements Compiler.Init.
   289  func (c *compiler) Init() {
   290  	c.currentGID = 0
   291  	c.nextVRegID = regalloc.VRegIDNonReservedBegin
   292  	c.returnVRegs = c.returnVRegs[:0]
   293  	c.mach.Reset()
   294  	c.varEdges = c.varEdges[:0]
   295  	c.constEdges = c.constEdges[:0]
   296  	c.buf = c.buf[:0]
   297  	c.sourceOffsets = c.sourceOffsets[:0]
   298  	c.relocations = c.relocations[:0]
   299  }
   300  
   301  // ValueDefinition implements Compiler.ValueDefinition.
   302  func (c *compiler) ValueDefinition(value ssa.Value) *SSAValueDefinition {
   303  	return &c.ssaValueDefinitions[value.ID()]
   304  }
   305  
   306  // VRegOf implements Compiler.VRegOf.
   307  func (c *compiler) VRegOf(value ssa.Value) regalloc.VReg {
   308  	return c.ssaValueToVRegs[value.ID()]
   309  }
   310  
   311  // Format implements Compiler.Format.
   312  func (c *compiler) Format() string {
   313  	return c.mach.Format()
   314  }
   315  
   316  // TypeOf implements Compiler.Format.
   317  func (c *compiler) TypeOf(v regalloc.VReg) ssa.Type {
   318  	return c.ssaTypeOfVRegID[v.ID()]
   319  }
   320  
   321  // MatchInstr implements Compiler.MatchInstr.
   322  func (c *compiler) MatchInstr(def *SSAValueDefinition, opcode ssa.Opcode) bool {
   323  	instr := def.Instr
   324  	return def.IsFromInstr() &&
   325  		instr.Opcode() == opcode &&
   326  		instr.GroupID() == c.currentGID &&
   327  		def.RefCount < 2
   328  }
   329  
   330  // MatchInstrOneOf implements Compiler.MatchInstrOneOf.
   331  func (c *compiler) MatchInstrOneOf(def *SSAValueDefinition, opcodes []ssa.Opcode) ssa.Opcode {
   332  	instr := def.Instr
   333  	if !def.IsFromInstr() {
   334  		return ssa.OpcodeInvalid
   335  	}
   336  
   337  	if instr.GroupID() != c.currentGID {
   338  		return ssa.OpcodeInvalid
   339  	}
   340  
   341  	if def.RefCount >= 2 {
   342  		return ssa.OpcodeInvalid
   343  	}
   344  
   345  	opcode := instr.Opcode()
   346  	for _, op := range opcodes {
   347  		if opcode == op {
   348  			return opcode
   349  		}
   350  	}
   351  	return ssa.OpcodeInvalid
   352  }
   353  
   354  // SSABuilder implements Compiler .SSABuilder.
   355  func (c *compiler) SSABuilder() ssa.Builder {
   356  	return c.ssaBuilder
   357  }
   358  
   359  // AddSourceOffsetInfo implements Compiler.AddSourceOffsetInfo.
   360  func (c *compiler) AddSourceOffsetInfo(executableOffset int64, sourceOffset ssa.SourceOffset) {
   361  	c.sourceOffsets = append(c.sourceOffsets, SourceOffsetInfo{
   362  		SourceOffset:     sourceOffset,
   363  		ExecutableOffset: executableOffset,
   364  	})
   365  }
   366  
   367  // SourceOffsetInfo implements Compiler.SourceOffsetInfo.
   368  func (c *compiler) SourceOffsetInfo() []SourceOffsetInfo {
   369  	return c.sourceOffsets
   370  }
   371  
   372  // AddRelocationInfo implements Compiler.AddRelocationInfo.
   373  func (c *compiler) AddRelocationInfo(funcRef ssa.FuncRef) {
   374  	c.relocations = append(c.relocations, RelocationInfo{
   375  		Offset:  int64(len(c.buf)),
   376  		FuncRef: funcRef,
   377  	})
   378  }
   379  
   380  // Emit8Bytes implements Compiler.Emit8Bytes.
   381  func (c *compiler) Emit8Bytes(b uint64) {
   382  	c.buf = append(c.buf, byte(b), byte(b>>8), byte(b>>16), byte(b>>24), byte(b>>32), byte(b>>40), byte(b>>48), byte(b>>56))
   383  }
   384  
   385  // Emit4Bytes implements Compiler.Emit4Bytes.
   386  func (c *compiler) Emit4Bytes(b uint32) {
   387  	c.buf = append(c.buf, byte(b), byte(b>>8), byte(b>>16), byte(b>>24))
   388  }
   389  
   390  // EmitByte implements Compiler.EmitByte.
   391  func (c *compiler) EmitByte(b byte) {
   392  	c.buf = append(c.buf, b)
   393  }
   394  
   395  // Buf implements Compiler.Buf.
   396  func (c *compiler) Buf() []byte {
   397  	return c.buf
   398  }
   399  
   400  // BufPtr implements Compiler.BufPtr.
   401  func (c *compiler) BufPtr() *[]byte {
   402  	return &c.buf
   403  }
   404  
   405  func (c *compiler) GetFunctionABI(sig *ssa.Signature) *FunctionABI {
   406  	if int(sig.ID) >= len(c.abis) {
   407  		c.abis = append(c.abis, make([]FunctionABI, int(sig.ID)+1)...)
   408  	}
   409  
   410  	abi := &c.abis[sig.ID]
   411  	if abi.Initialized {
   412  		return abi
   413  	}
   414  
   415  	abi.Init(sig, c.argResultInts, c.argResultFloats)
   416  	return abi
   417  }