github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/backend/compiler_lower.go

package backend

import (
	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc"
	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
)

// Lower implements Compiler.Lower.
func (c *compiler) Lower() {
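	// Lowering happens in three steps: assign a virtual register to each SSA value,
	// tell the machine backend the ABI for the function's signature, and then lower
	// each basic block into machine-specific instructions.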
	c.assignVirtualRegisters()
	c.mach.SetCurrentABI(c.GetFunctionABI(c.ssaBuilder.Signature()))
	c.mach.ExecutableContext().StartLoweringFunction(c.ssaBuilder.BlockIDMax())
	c.lowerBlocks()
}

// lowerBlocks lowers each block in the ssa.Builder.
func (c *compiler) lowerBlocks() {
	builder := c.ssaBuilder
	for blk := builder.BlockIteratorReversePostOrderBegin(); blk != nil; blk = builder.BlockIteratorReversePostOrderNext() {
		c.lowerBlock(blk)
	}

	ectx := c.mach.ExecutableContext()
	// After lowering all blocks, we link adjacent blocks so that they are laid out as one single instruction list.
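	// For example, if the reverse post-order is blk0, blk1, blk2, the loop below
	// calls LinkAdjacentBlocks(blk0, blk1) and then LinkAdjacentBlocks(blk1, blk2).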
	var prev ssa.BasicBlock
	for next := builder.BlockIteratorReversePostOrderBegin(); next != nil; next = builder.BlockIteratorReversePostOrderNext() {
		if prev != nil {
			ectx.LinkAdjacentBlocks(prev, next)
		}
		prev = next
	}
}

func (c *compiler) lowerBlock(blk ssa.BasicBlock) {
	mach := c.mach
	ectx := mach.ExecutableContext()
	ectx.StartBlock(blk)

	// We traverse the instructions in reverse order because we might want to lower multiple
	// instructions together.
	cur := blk.Tail()

	// First, gather the branching instructions at the end of the block.
	var br0, br1 *ssa.Instruction
	if cur.IsBranching() {
		br0 = cur
		cur = cur.Prev()
		if cur != nil && cur.IsBranching() {
			br1 = cur
			cur = cur.Prev()
		}
	}
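	// For example, a block ending with a conditional branch followed by an
	// unconditional jump, schematically:
	//
	//	brz v1, blk2
	//	jump blk3
	//
	// yields br1 = the conditional `brz` and br0 = the unconditional `jump`, since
	// a block always ends with an unconditional branch, optionally preceded by a
	// conditional one.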

	if br0 != nil {
		c.lowerBranches(br0, br1)
	}

	if br1 != nil && br0 == nil {
		panic("BUG: a block has a conditional branch but doesn't end with an unconditional branch")
	}

	// Now start lowering the non-branching instructions.
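	// An instruction may have already been lowered together with a later instruction
	// (see the comment on the reverse traversal above); Lowered() reports true for
	// such instructions and they are skipped below.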
	for ; cur != nil; cur = cur.Prev() {
		c.setCurrentGroupID(cur.GroupID())
		if cur.Lowered() {
			continue
		}

		switch cur.Opcode() {
		case ssa.OpcodeReturn:
			rets := cur.ReturnVals()
			if len(rets) > 0 {
				c.mach.LowerReturns(rets)
			}
			c.mach.InsertReturn()
		default:
			mach.LowerInstr(cur)
		}
		ectx.FlushPendingInstructions()
	}

	// Finally, if this is the entry block, we have to insert copies of the arguments from their real locations to the corresponding VRegs.
	if blk.EntryBlock() {
		c.lowerFunctionArguments(blk)
	}

	ectx.EndBlock()
}

// lowerBranches is called right after StartBlock and before any LowerInstr call, if
// the given block ends with branches. br0 is the branch at the very end of the block,
// and br1 is the one right before br0 if it exists. br0 is never nil, but br1 can be
// nil if there's no conditional branch before br0.
//
// See ssa.Instruction IsBranching, and the comment on ssa.BasicBlock.
func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) {
	ectx := c.mach.ExecutableContext()

	c.setCurrentGroupID(br0.GroupID())
	c.mach.LowerSingleBranch(br0)
	ectx.FlushPendingInstructions()
	if br1 != nil {
		c.setCurrentGroupID(br1.GroupID())
		c.mach.LowerConditionalBranch(br1)
		ectx.FlushPendingInstructions()
	}

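	// An unconditional jump may carry block arguments (the SSA counterpart of phi
	// operands). If it does, it must be the only branch in this block: a block that
	// both branches conditionally and passes arguments would form a critical edge,
	// which is expected to have been split before lowering, hence the panic below.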
	if br0.Opcode() == ssa.OpcodeJump {
		_, args, target := br0.BranchData()
		argExists := len(args) != 0
		if argExists && br1 != nil {
			panic("BUG: critical edge split failed")
		}
		if argExists && target.ReturnBlock() {
			if len(args) > 0 {
				c.mach.LowerReturns(args)
			}
		} else if argExists {
			c.lowerBlockArguments(args, target)
		}
	}
	ectx.FlushPendingInstructions()
}

func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) {
	ectx := c.mach.ExecutableContext()

	c.tmpVals = c.tmpVals[:0]
	for i := 0; i < entry.Params(); i++ {
		p := entry.Param(i)
		if c.ssaValueRefCounts[p.ID()] > 0 {
			c.tmpVals = append(c.tmpVals, p)
		} else {
			// If the argument is not used, we can just pass an invalid value.
			c.tmpVals = append(c.tmpVals, ssa.ValueInvalid)
		}
	}
	c.mach.LowerParams(c.tmpVals)
	ectx.FlushPendingInstructions()
}

// lowerBlockArguments lowers the passing of arguments to the given successor block.
func (c *compiler) lowerBlockArguments(args []ssa.Value, succ ssa.BasicBlock) {
	if len(args) != succ.Params() {
		panic("BUG: mismatched number of arguments")
	}

	c.varEdges = c.varEdges[:0]
	c.varEdgeTypes = c.varEdgeTypes[:0]
	c.constEdges = c.constEdges[:0]
	for i := 0; i < len(args); i++ {
		dst := succ.Param(i)
		src := args[i]

		dstReg := c.VRegOf(dst)
		srcDef := c.ssaValueDefinitions[src.ID()]
		if srcDef.IsFromInstr() && srcDef.Instr.Constant() {
			c.constEdges = append(c.constEdges, struct {
				cInst *ssa.Instruction
				dst   regalloc.VReg
			}{cInst: srcDef.Instr, dst: dstReg})
		} else {
			srcReg := c.VRegOf(src)
			// Even when src == dst, insert the move so that such registers are kept alive.
			c.varEdges = append(c.varEdges, [2]regalloc.VReg{srcReg, dstReg})
			c.varEdgeTypes = append(c.varEdgeTypes, src.Type())
		}
	}

	// Check if there's an overlap among the dsts and srcs in varEdges.
	c.vRegIDs = c.vRegIDs[:0]
	for _, edge := range c.varEdges {
		src := edge[0].ID()
		if int(src) >= len(c.vRegSet) {
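			// Grow vRegSet so that index src is addressable; appending src+1 zeroed
			// entries guarantees len(c.vRegSet) > src.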
			c.vRegSet = append(c.vRegSet, make([]bool, src+1)...)
		}
		c.vRegSet[src] = true
		c.vRegIDs = append(c.vRegIDs, src)
	}
	separated := true
	for _, edge := range c.varEdges {
		dst := edge[1].ID()
		if int(dst) >= len(c.vRegSet) {
			c.vRegSet = append(c.vRegSet, make([]bool, dst+1)...)
		} else {
			if c.vRegSet[dst] {
				separated = false
				break
			}
		}
	}
	for _, id := range c.vRegIDs {
		c.vRegSet[id] = false // reset for the next use.
	}

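	// For example, edges {v1 -> v2, v2 -> v1} form a swap: v2 is both a source and a
	// destination, so moving directly would clobber v2 before it is read. When such an
	// overlap exists, we go through temporary registers below.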
	if separated {
		// If there's no overlap, we can simply move each source to its destination.
		for i, edge := range c.varEdges {
			src, dst := edge[0], edge[1]
			c.mach.InsertMove(dst, src, c.varEdgeTypes[i])
		}
	} else {
		// Otherwise, we go through temporary registers so that no source is clobbered
		// before it is read.
		//
		// First, move all the sources to temporary registers.
		c.tempRegs = c.tempRegs[:0]
		for i, edge := range c.varEdges {
			src := edge[0]
			typ := c.varEdgeTypes[i]
			temp := c.AllocateVReg(typ)
			c.tempRegs = append(c.tempRegs, temp)
			c.mach.InsertMove(temp, src, typ)
		}
		// Then move the temporary registers to the destinations.
		for i, edge := range c.varEdges {
			temp := c.tempRegs[i]
			dst := edge[1]
			c.mach.InsertMove(dst, temp, c.varEdgeTypes[i])
		}
	}

	// Finally, move the constants.
	for _, edge := range c.constEdges {
		cInst, dst := edge.cInst, edge.dst
		c.mach.InsertLoadConstantBlockArg(cInst, dst)
	}
}