github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/backend/compiler_lower.go

package backend

import (
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/backend/regalloc"
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/ssa"
)

// Lower implements Compiler.Lower.
func (c *compiler) Lower() {
	c.assignVirtualRegisters()
	c.mach.InitializeABI(c.ssaBuilder.Signature())
	c.mach.ExecutableContext().StartLoweringFunction(c.ssaBuilder.BlockIDMax())
	c.lowerBlocks()
}

// lowerBlocks lowers each block in the ssa.Builder.
func (c *compiler) lowerBlocks() {
	builder := c.ssaBuilder
	for blk := builder.BlockIteratorReversePostOrderBegin(); blk != nil; blk = builder.BlockIteratorReversePostOrderNext() {
		c.lowerBlock(blk)
	}

	ectx := c.mach.ExecutableContext()
	// After lowering all blocks, link adjacent blocks so that they are laid out as a
	// single contiguous instruction list.
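	// (This adjacency information is presumably what later lets the backend emit a
	// jump to the immediately following block as a simple fallthrough.)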
	var prev ssa.BasicBlock
	for next := builder.BlockIteratorReversePostOrderBegin(); next != nil; next = builder.BlockIteratorReversePostOrderNext() {
		if prev != nil {
			ectx.LinkAdjacentBlocks(prev, next)
		}
		prev = next
	}
}

func (c *compiler) lowerBlock(blk ssa.BasicBlock) {
	mach := c.mach
	mach.StartBlock(blk)

	// We traverse the instructions in reverse order because we might want to lower multiple
	// instructions together.
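	// For example, a compare instruction whose only use is the conditional branch that
	// follows it may, on some targets, be lowered together with that branch as a single
	// fused compare-and-branch sequence.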
	cur := blk.Tail()

	// First, gather the branching instructions at the end of the block.
	var br0, br1 *ssa.Instruction
	if cur.IsBranching() {
		br0 = cur
		cur = cur.Prev()
		if cur != nil && cur.IsBranching() {
			br1 = cur
			cur = cur.Prev()
		}
	}
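	// For example, a block ending with `brnz v1, blk_a; jump blk_b` yields
	// br1 = brnz (the conditional branch) and br0 = jump (the unconditional one).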

	if br0 != nil {
		c.lowerBranches(br0, br1)
	}

	if br1 != nil && br0 == nil {
		panic("BUG: found a conditional branch without an unconditional branch at the end of the block")
	}

	// Now start lowering the non-branching instructions.
	ectx := mach.ExecutableContext()
	for ; cur != nil; cur = cur.Prev() {
		c.setCurrentGroupID(cur.GroupID())
		if cur.Lowered() {
			continue
		}

		switch cur.Opcode() {
		case ssa.OpcodeReturn:
			c.lowerFunctionReturns(cur.ReturnVals())
			c.mach.InsertReturn()
		default:
			mach.LowerInstr(cur)
		}
		ectx.FlushPendingInstructions()
	}

	// Finally, if this is the entry block, we have to insert copies of arguments
	// from their real (ABI-assigned) locations to the corresponding VRegs.
	if blk.EntryBlock() {
		c.lowerFunctionArguments(blk)
	}

	ectx.EndBlock()
}

// lowerBranches is called right after StartBlock and before any LowerInstr call if
// the given block ends with branch instructions. br0 is the instruction at the very
// end of the block, and br1 is the instruction immediately before br0 if such a
// branch exists. br0 is never nil, but br1 can be nil if there's no conditional
// branch before br0.
//
// See ssa.Instruction IsBranching, and the comment on ssa.BasicBlock.
func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) {
	ectx := c.mach.ExecutableContext()

	c.setCurrentGroupID(br0.GroupID())
	c.mach.LowerSingleBranch(br0)
	ectx.FlushPendingInstructions()
	if br1 != nil {
		c.setCurrentGroupID(br1.GroupID())
		c.mach.LowerConditionalBranch(br1)
		ectx.FlushPendingInstructions()
	}

	if br0.Opcode() == ssa.OpcodeJump {
		_, args, target := br0.BranchData()
		argExists := len(args) != 0
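		// If the jump carries block arguments, the SSA builder should have split any
		// critical edge leading here, leaving this block with a single successor; a
		// preceding conditional branch would mean that split failed.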
		if argExists && br1 != nil {
			panic("BUG: critical edge split failed")
		}
		if argExists && target.ReturnBlock() {
			c.lowerFunctionReturns(args)
		} else if argExists {
			c.lowerBlockArguments(args, target)
		}
	}
	ectx.FlushPendingInstructions()
}

func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) {
	ectx := c.mach.ExecutableContext()

	c.tmpVals = c.tmpVals[:0]
	for i := 0; i < entry.Params(); i++ {
		p := entry.Param(i)
		if c.ssaValueRefCounts[p.ID()] > 0 {
			c.tmpVals = append(c.tmpVals, p)
		} else {
			// If the argument is not used, we can just pass an invalid value.
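			// (Presumably CalleeGenFunctionArgsToVRegs then skips such slots
			// rather than emitting a copy for them.)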
			c.tmpVals = append(c.tmpVals, ssa.ValueInvalid)
		}
	}
	c.mach.ABI().CalleeGenFunctionArgsToVRegs(c.tmpVals)
	ectx.FlushPendingInstructions()
}

func (c *compiler) lowerFunctionReturns(returns []ssa.Value) {
	c.mach.ABI().CalleeGenVRegsToFunctionReturns(returns)
}

// lowerBlockArguments lowers the passing of arguments to the given successor block.
func (c *compiler) lowerBlockArguments(args []ssa.Value, succ ssa.BasicBlock) {
	if len(args) != succ.Params() {
		panic("BUG: mismatched number of arguments")
	}

	c.varEdges = c.varEdges[:0]
	c.varEdgeTypes = c.varEdgeTypes[:0]
	c.constEdges = c.constEdges[:0]
	for i := 0; i < len(args); i++ {
		dst := succ.Param(i)
		src := args[i]

		dstReg := c.VRegOf(dst)
		srcDef := c.ssaValueDefinitions[src.ID()]
		if srcDef.IsFromInstr() && srcDef.Instr.Constant() {
			c.constEdges = append(c.constEdges, struct {
				cInst *ssa.Instruction
				dst   regalloc.VReg
			}{cInst: srcDef.Instr, dst: dstReg})
		} else {
			srcReg := c.VRegOf(src)
			// Even when src == dst, insert the move so that such registers are kept alive.
			c.varEdges = append(c.varEdges, [2]regalloc.VReg{srcReg, dstReg})
			c.varEdgeTypes = append(c.varEdgeTypes, src.Type())
		}
	}

	// Check if there's an overlap among the dsts and srcs in varEdges.
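	// For example, the parallel moves {v1 -> v2, v2 -> v1} (a swap) overlap:
	// executing them one by one would clobber v2 before it is read as a source.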
	c.vRegIDs = c.vRegIDs[:0]
	for _, edge := range c.varEdges {
		src := edge[0].ID()
		if int(src) >= len(c.vRegSet) {
			c.vRegSet = append(c.vRegSet, make([]bool, src+1)...)
		}
		c.vRegSet[src] = true
		c.vRegIDs = append(c.vRegIDs, src)
	}
	separated := true
	for _, edge := range c.varEdges {
		dst := edge[1].ID()
		if int(dst) >= len(c.vRegSet) {
			c.vRegSet = append(c.vRegSet, make([]bool, dst+1)...)
		} else {
			if c.vRegSet[dst] {
				separated = false
				break
			}
		}
	}
	for _, id := range c.vRegIDs {
		c.vRegSet[id] = false // reset for the next use.
	}

	if separated {
		// If there's no overlap, we can simply move each source to its destination.
		for i, edge := range c.varEdges {
			src, dst := edge[0], edge[1]
			c.mach.InsertMove(dst, src, c.varEdgeTypes[i])
		}
	} else {
		// Otherwise, we route every move through a freshly allocated temporary
		// register so that no source is clobbered before it has been read:
		//
		// First, move all the sources to temporary registers.
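		// For the swap example above this emits, in effect,
		//   tmp1 := v1; tmp2 := v2; v2 := tmp1; v1 := tmp2
		// at the cost of two extra virtual registers.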
		c.tempRegs = c.tempRegs[:0]
		for i, edge := range c.varEdges {
			src := edge[0]
			typ := c.varEdgeTypes[i]
			temp := c.AllocateVReg(typ)
			c.tempRegs = append(c.tempRegs, temp)
			c.mach.InsertMove(temp, src, typ)
		}
		// Then move the temporary registers to the destinations.
		for i, edge := range c.varEdges {
			temp := c.tempRegs[i]
			dst := edge[1]
			c.mach.InsertMove(dst, temp, c.varEdgeTypes[i])
		}
	}

	// Finally, move the constants.
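	// Constants are materialized directly into their destination registers rather
	// than moved from a source VReg, and doing this after the moves above means a
	// constant load cannot clobber a source that one of those moves still needed.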
	for _, edge := range c.constEdges {
		cInst, dst := edge.cInst, edge.dst
		c.mach.InsertLoadConstant(cInst, dst)
	}
}