github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/wazevo/backend/isa/arm64/machine.go

     1  package arm64
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"strings"
     7  
     8  	"github.com/wasilibs/wazerox/internal/engine/wazevo/backend"
     9  	"github.com/wasilibs/wazerox/internal/engine/wazevo/backend/regalloc"
    10  	"github.com/wasilibs/wazerox/internal/engine/wazevo/ssa"
    11  	"github.com/wasilibs/wazerox/internal/engine/wazevo/wazevoapi"
    12  )
    13  
    14  type (
    15  	// machine implements backend.Machine.
    16  	machine struct {
    17  		compiler      backend.Compiler
    18  		currentABI    *abiImpl
    19  		currentSSABlk ssa.BasicBlock
    20  		// abis maps ssa.SignatureID to the ABI implementation.
    21  		abis      []abiImpl
    22  		instrPool wazevoapi.Pool[instruction]
    23  		// rootInstr is the root instruction of the currently-compiled function.
    24  		rootInstr *instruction
    25  		// perBlockHead and perBlockEnd are the head and tail of the instruction list per currently-compiled ssa.BasicBlock.
    26  		perBlockHead, perBlockEnd *instruction
    27  		// pendingInstructions are the instructions which are not yet emitted into the instruction list.
    28  		pendingInstructions []*instruction
    29  		regAllocFn          regAllocFunctionImpl
    30  		nextLabel           label
    31  
    32  		// ssaBlockIDToLabels maps an SSA block ID to the label.
    33  		ssaBlockIDToLabels []label
    34  		// labelPositions maps a label to the labelPosition of the region which the label represents.
    35  		labelPositions     map[label]*labelPosition
    36  		orderedBlockLabels []*labelPosition
    37  		labelPositionPool  wazevoapi.Pool[labelPosition]
    38  
    39  		// addendsWorkQueue is used during address lowering, defined here for reuse.
    40  		addendsWorkQueue queue[ssa.Value]
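        		// addends32 is used during address lowering, defined here for reuse.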
    41  		addends32        queue[addend32]
    42  		// addends64 is used during address lowering, defined here for reuse.
    43  		addends64              queue[regalloc.VReg]
    44  		unresolvedAddressModes []*instruction
    45  
    46  		// condBrRelocs holds the conditional branches which need offset relocation.
    47  		condBrRelocs []condBrReloc
    48  
    49  		// spillSlotSize is the total size, in bytes, of the stack slots used for spilling registers.
    50  		// During the execution of the function, the stack looks like:
    51  		//
    52  		//
    53  		//            (high address)
    54  		//          +-----------------+
    55  		//          |     .......     |
    56  		//          |      ret Y      |
    57  		//          |     .......     |
    58  		//          |      ret 0      |
    59  		//          |      arg X      |
    60  		//          |     .......     |
    61  		//          |      arg 1      |
    62  		//          |      arg 0      |
    63  		//          |      xxxxx      |
    64  		//          |   ReturnAddress |
    65  		//          +-----------------+   <<-|
    66  		//          |   ...........   |      |
    67  		//          |   spill slot M  |      | <--- spillSlotSize
    68  		//          |   ............  |      |
    69  		//          |   spill slot 2  |      |
    70  		//          |   spill slot 1  |   <<-+
    71  		//          |   clobbered N   |
    72  		//          |   ...........   |
    73  		//          |   clobbered 1   |
    74  		//          |   clobbered 0   |
    75  		//   SP---> +-----------------+
    76  		//             (low address)
    77  		//
    78  		// and spillSlotSize is the size of the spill slot region shown above. This must be a multiple of 16.
    79  		// Also note that this value is only known after register allocation.
    80  		spillSlotSize int64
    81  		spillSlots    map[regalloc.VRegID]int64 // regalloc.VRegID to offset.
    82  		// clobberedRegs holds real-register-backed VRegs saved in the function prologue and restored in the epilogue.
    83  		clobberedRegs []regalloc.VReg
    84  
    85  		maxRequiredStackSizeForCalls int64
    86  		stackBoundsCheckDisabled     bool
    87  
    88  		regAllocStarted bool
    89  	}
    90  
    91  	addend32 struct {
    92  		r   regalloc.VReg
    93  		ext extendOp
    94  	}
    95  
    96  	// label represents a position in the generated code which is either
    97  	// a real instruction or the constant pool (e.g. jump tables).
    98  	//
    99  	// This is exactly the same as the traditional "label" in assembly code.
   100  	label uint32
   101  
   102  	// labelPosition represents the region of the generated code covered by the corresponding label.
   103  	labelPosition struct {
   104  		l            label
   105  		begin, end   *instruction
   106  		binarySize   int64
   107  		binaryOffset int64
   108  	}
   109  
   110  	condBrReloc struct {
   111  		cbr *instruction
   112  		// currentLabelPos is the labelPosition within which cbr is defined.
   113  		currentLabelPos *labelPosition
   114  		// nextLabel is the label of the block that follows currentLabelPos.
   115  		nextLabel label
   116  		offset    int64
   117  	}
   118  )
   119  
   120  const (
   121  	invalidLabel = 0
   122  	returnLabel  = math.MaxUint32
   123  )
   124  
   125  // NewBackend returns a new backend for arm64.
   126  func NewBackend() backend.Machine {
   127  	m := &machine{
   128  		instrPool:         wazevoapi.NewPool[instruction](resetInstruction),
   129  		labelPositionPool: wazevoapi.NewPool[labelPosition](resetLabelPosition),
   130  		labelPositions:    make(map[label]*labelPosition),
   131  		spillSlots:        make(map[regalloc.VRegID]int64),
   132  		nextLabel:         invalidLabel,
   133  	}
   134  	m.regAllocFn.m = m
   135  	m.regAllocFn.labelToRegAllocBlockIndex = make(map[label]int)
   136  	return m
   137  }
   138  
   139  // Reset implements backend.Machine.
   140  func (m *machine) Reset() {
   141  	m.regAllocStarted = false
   142  	m.instrPool.Reset()
   143  	m.labelPositionPool.Reset()
   144  	m.currentSSABlk = nil
   145  	for l := label(0); l <= m.nextLabel; l++ {
   146  		delete(m.labelPositions, l)
   147  	}
   148  	m.pendingInstructions = m.pendingInstructions[:0]
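        	// Reuse clobberedRegs as scratch space to collect the spill slot keys, so that the spillSlots map can be cleared without allocating.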
   149  	m.clobberedRegs = m.clobberedRegs[:0]
   150  	for key := range m.spillSlots {
   151  		m.clobberedRegs = append(m.clobberedRegs, regalloc.VReg(key))
   152  	}
   153  	for _, key := range m.clobberedRegs {
   154  		delete(m.spillSlots, regalloc.VRegID(key))
   155  	}
   156  	m.clobberedRegs = m.clobberedRegs[:0]
   157  	m.orderedBlockLabels = m.orderedBlockLabels[:0]
   158  	m.regAllocFn.reset()
   159  	m.spillSlotSize = 0
   160  	m.unresolvedAddressModes = m.unresolvedAddressModes[:0]
   161  	m.rootInstr = nil
   162  	m.ssaBlockIDToLabels = m.ssaBlockIDToLabels[:0]
   163  	m.perBlockHead, m.perBlockEnd = nil, nil
   164  	m.maxRequiredStackSizeForCalls = 0
   165  	m.nextLabel = invalidLabel
   166  }
   167  
   168  // InitializeABI implements backend.Machine InitializeABI.
   169  func (m *machine) InitializeABI(sig *ssa.Signature) {
   170  	m.currentABI = m.getOrCreateABIImpl(sig)
   171  }
   172  
   173  // DisableStackCheck implements backend.Machine DisableStackCheck.
   174  func (m *machine) DisableStackCheck() {
   175  	m.stackBoundsCheckDisabled = true
   176  }
   177  
   178  // ABI implements backend.Machine.
   179  func (m *machine) ABI() backend.FunctionABI {
   180  	return m.currentABI
   181  }
   182  
   183  // allocateLabel allocates an unused label.
   184  func (m *machine) allocateLabel() label {
   185  	m.nextLabel++
   186  	return m.nextLabel
   187  }
   188  
   189  // SetCompiler implements backend.Machine.
   190  func (m *machine) SetCompiler(ctx backend.Compiler) {
   191  	m.compiler = ctx
   192  }
   193  
   194  // StartLoweringFunction implements backend.Machine.
   195  func (m *machine) StartLoweringFunction(max ssa.BasicBlockID) {
   196  	imax := int(max)
   197  	if len(m.ssaBlockIDToLabels) <= imax {
   198  		// Eagerly grow the slice so that it can be indexed by any block ID; the underlying array is reused for the next compilation.
   199  		m.ssaBlockIDToLabels = append(m.ssaBlockIDToLabels, make([]label, imax+1)...)
   200  	}
   201  }
   202  
   203  // EndLoweringFunction implements backend.Machine.
   204  func (m *machine) EndLoweringFunction() {}
   205  
   206  // StartBlock implements backend.Machine.
   207  func (m *machine) StartBlock(blk ssa.BasicBlock) {
   208  	m.currentSSABlk = blk
   209  
   210  	l := m.ssaBlockIDToLabels[m.currentSSABlk.ID()]
   211  	if l == invalidLabel {
   212  		l = m.allocateLabel()
   213  		m.ssaBlockIDToLabels[blk.ID()] = l
   214  	}
   215  
   216  	end := m.allocateNop()
   217  	m.perBlockHead, m.perBlockEnd = end, end
   218  
   219  	labelPos, ok := m.labelPositions[l]
   220  	if !ok {
   221  		labelPos = m.allocateLabelPosition(l)
   222  		m.labelPositions[l] = labelPos
   223  	}
   224  	m.orderedBlockLabels = append(m.orderedBlockLabels, labelPos)
   225  	labelPos.begin, labelPos.end = end, end
   226  	m.regAllocFn.addBlock(blk, l, labelPos)
   227  }
   228  
   229  // EndBlock implements backend.Machine.
   230  func (m *machine) EndBlock() {
   231  	// Insert a nop0 at the head of the block; this simplifies the logic of inserting instructions.
   232  	m.insertAtPerBlockHead(m.allocateNop())
   233  
   234  	l := m.ssaBlockIDToLabels[m.currentSSABlk.ID()]
   235  	m.labelPositions[l].begin = m.perBlockHead
   236  
   237  	if m.currentSSABlk.EntryBlock() {
   238  		m.rootInstr = m.perBlockHead
   239  	}
   240  }
   241  
   242  func (m *machine) insert(i *instruction) {
   243  	m.pendingInstructions = append(m.pendingInstructions, i)
   244  }
   245  
   246  func (m *machine) insertBrTargetLabel() label {
   247  	nop, l := m.allocateBrTarget()
   248  	m.insert(nop)
   249  	return l
   250  }
   251  
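        // allocateBrTarget allocates a new label and the nop0 instruction that anchors it; branches can then target
        // the position of that nop.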
   252  func (m *machine) allocateBrTarget() (nop *instruction, l label) {
   253  	l = m.allocateLabel()
   254  	nop = m.allocateInstr()
   255  	nop.asNop0WithLabel(l)
   256  	pos := m.allocateLabelPosition(l)
   257  	pos.begin, pos.end = nop, nop
   258  	m.labelPositions[l] = pos
   259  	return
   260  }
   261  
   262  func (m *machine) allocateLabelPosition(la label) *labelPosition {
   263  	l := m.labelPositionPool.Allocate()
   264  	l.l = la
   265  	return l
   266  }
   267  
   268  func resetLabelPosition(l *labelPosition) {
   269  	*l = labelPosition{}
   270  }
   271  
   272  // FlushPendingInstructions implements backend.Machine.
   273  func (m *machine) FlushPendingInstructions() {
   274  	l := len(m.pendingInstructions)
   275  	if l == 0 {
   276  		return
   277  	}
   278  	for i := l - 1; i >= 0; i-- { // reverse because we lower instructions in reverse order.
   279  		m.insertAtPerBlockHead(m.pendingInstructions[i])
   280  	}
   281  	m.pendingInstructions = m.pendingInstructions[:0]
   282  }
   283  
   284  func (m *machine) insertAtPerBlockHead(i *instruction) {
   285  	if m.perBlockHead == nil {
   286  		m.perBlockHead = i
   287  		m.perBlockEnd = i
   288  		return
   289  	}
   290  	i.next = m.perBlockHead
   291  	m.perBlockHead.prev = i
   292  	m.perBlockHead = i
   293  }
   294  
   295  // String implements backend.Machine.
   296  func (l label) String() string {
   297  	return fmt.Sprintf("L%d", l)
   298  }
   299  
   300  // allocateInstr allocates an instruction.
   301  func (m *machine) allocateInstr() *instruction {
   302  	instr := m.instrPool.Allocate()
   303  	if !m.regAllocStarted {
   304  		instr.addedBeforeRegAlloc = true
   305  	}
   306  	return instr
   307  }
   308  
   309  func resetInstruction(i *instruction) {
   310  	*i = instruction{}
   311  }
   312  
   313  func (m *machine) allocateNop() *instruction {
   314  	instr := m.allocateInstr()
   315  	instr.asNop0()
   316  	return instr
   317  }
   318  
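        // resolveAddressingMode finalizes the addressing mode of a load/store that refers to the argument or result
        // stack space, now that the offsets of those regions from SP are known.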
   319  func (m *machine) resolveAddressingMode(arg0offset, ret0offset int64, i *instruction) {
   320  	amode := &i.amode
   321  	switch amode.kind {
   322  	case addressModeKindResultStackSpace:
   323  		amode.imm += ret0offset
   324  	case addressModeKindArgStackSpace:
   325  		amode.imm += arg0offset
   326  	default:
   327  		panic("BUG")
   328  	}
   329  
   330  	var sizeInBits byte
   331  	switch i.kind {
   332  	case store8, uLoad8:
   333  		sizeInBits = 8
   334  	case store16, uLoad16:
   335  		sizeInBits = 16
   336  	case store32, fpuStore32, uLoad32, fpuLoad32:
   337  		sizeInBits = 32
   338  	case store64, fpuStore64, uLoad64, fpuLoad64:
   339  		sizeInBits = 64
   340  	case fpuStore128, fpuLoad128:
   341  		sizeInBits = 128
   342  	default:
   343  		panic("BUG")
   344  	}
   345  
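        	// If the resolved offset fits in the (access-size-scaled) unsigned 12-bit immediate, use it directly.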
   346  	if offsetFitsInAddressModeKindRegUnsignedImm12(sizeInBits, amode.imm) {
   347  		amode.kind = addressModeKindRegUnsignedImm12
   348  	} else {
   349  		// In this case, we load the offset into the temporary register,
   350  		// and then use it as the index register.
   351  		newPrev := m.lowerConstantI64AndInsert(i.prev, tmpRegVReg, amode.imm)
   352  		linkInstr(newPrev, i)
   353  		*amode = addressMode{kind: addressModeKindRegReg, rn: amode.rn, rm: tmpRegVReg, extOp: extendOpUXTX /* indicates rm reg is 64-bit */}
   354  	}
   355  }
   356  
   357  // ResolveRelativeAddresses implements backend.Machine.
   358  func (m *machine) ResolveRelativeAddresses() {
   359  	if len(m.unresolvedAddressModes) > 0 {
   360  		arg0offset, ret0offset := m.arg0OffsetFromSP(), m.ret0OffsetFromSP()
   361  		for _, i := range m.unresolvedAddressModes {
   362  			m.resolveAddressingMode(arg0offset, ret0offset, i)
   363  		}
   364  	}
   365  
   366  	// Reuse the slice to gather the unresolved conditional branches.
   367  	cbrs := m.condBrRelocs[:0]
   368  
   369  	// Next, in order to determine the offsets of relative jumps, we have to calculate the size of the region each label covers.
   370  	var offset int64
   371  	for i, pos := range m.orderedBlockLabels {
   372  		pos.binaryOffset = offset
   373  		var size int64
   374  		for cur := pos.begin; ; cur = cur.next {
   375  			switch cur.kind {
   376  			case nop0:
   377  				l := cur.nop0Label()
   378  				if pos, ok := m.labelPositions[l]; ok {
   379  					pos.binaryOffset = offset + size
   380  				}
   381  			case condBr:
   382  				if !cur.condBrOffsetResolved() {
   383  					var nextLabel label
   384  					if i < len(m.orderedBlockLabels)-1 {
   385  						// Note: this is only used when the block ends with a fallthrough,
   386  						// so we can safely assume that the next block exists when it is needed.
   387  						nextLabel = m.orderedBlockLabels[i+1].l
   388  					}
   389  					cbrs = append(cbrs, condBrReloc{
   390  						cbr: cur, currentLabelPos: pos, offset: offset + size,
   391  						nextLabel: nextLabel,
   392  					})
   393  				}
   394  			}
   395  			size += cur.size()
   396  			if cur == pos.end {
   397  				break
   398  			}
   399  		}
   400  		pos.binarySize = size
   401  		offset += size
   402  	}
   403  
   404  	// Before resolving any offsets, we need to check if all the conditional branches can be resolved.
   405  	var needRerun bool
   406  	for i := range cbrs {
   407  		reloc := &cbrs[i]
   408  		cbr := reloc.cbr
   409  		offset := reloc.offset
   410  
   411  		target := cbr.condBrLabel()
   412  		offsetOfTarget := m.labelPositions[target].binaryOffset
   413  		diff := offsetOfTarget - offset
   414  		if divided := diff >> 2; divided < minSignedInt19 || divided > maxSignedInt19 {
   415  			// In this case, the target of the conditional branch is too far away. We place a trampoline (an unconditional
   416  			// branch to the original target) at the end of the current block, and retarget the conditional branch to it.
   417  			m.insertConditionalJumpTrampoline(cbr, reloc.currentLabelPos, reloc.nextLabel)
   418  			// Then, we need to re-run this function to fix up the label offsets,
   419  			// as they change once the trampoline is inserted.
   420  			needRerun = true
   421  		}
   422  	}
   423  	if needRerun {
   424  		m.ResolveRelativeAddresses()
   425  		return
   426  	}
   427  
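        	// Finally, walk the instruction list and resolve branch offsets and jump table entries using the label offsets computed above.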
   428  	var currentOffset int64
   429  	for cur := m.rootInstr; cur != nil; cur = cur.next {
   430  		switch cur.kind {
   431  		case br:
   432  			target := cur.brLabel()
   433  			offsetOfTarget := m.labelPositions[target].binaryOffset
   434  			diff := offsetOfTarget - currentOffset
   435  			divided := diff >> 2
   436  			if divided < minSignedInt26 || divided > maxSignedInt26 {
   437  				// This means the single function currently being compiled is extremely large.
   438  				panic("too large function that requires branch relocation of large unconditional branch larger than 26-bit range")
   439  			}
   440  			cur.brOffsetResolve(diff)
   441  		case condBr:
   442  			if !cur.condBrOffsetResolved() {
   443  				target := cur.condBrLabel()
   444  				offsetOfTarget := m.labelPositions[target].binaryOffset
   445  				diff := offsetOfTarget - currentOffset
   446  				if divided := diff >> 2; divided < minSignedInt19 || divided > maxSignedInt19 {
   447  					panic("BUG: branch relocation for large conditional branch larger than 19-bit range must be handled properly")
   448  				}
   449  				cur.condBrOffsetResolve(diff)
   450  			}
   451  		case brTableSequence:
   452  			for i := range cur.targets {
   453  				l := label(cur.targets[i])
   454  				offsetOfTarget := m.labelPositions[l].binaryOffset
   455  				diff := offsetOfTarget - (currentOffset + brTableSequenceOffsetTableBegin)
   456  				cur.targets[i] = uint32(diff)
   457  			}
   458  			cur.brTableSequenceOffsetsResolved()
   459  		case emitSourceOffsetInfo:
   460  			m.compiler.AddSourceOffsetInfo(currentOffset, cur.sourceOffsetInfo())
   461  		}
   462  		currentOffset += cur.size()
   463  	}
   464  }
   465  
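        // These are the signed ranges, in units of 4-byte instructions, encodable in the 26-bit immediate of an
        // unconditional branch (B) and the 19-bit immediate of a conditional branch (B.cond) on arm64.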
   466  const (
   467  	maxSignedInt26 int64 = 1<<25 - 1
   468  	minSignedInt26 int64 = -(1 << 25)
   469  
   470  	maxSignedInt19 int64 = 1<<18 - 1
   471  	minSignedInt19 int64 = -(1 << 18)
   472  )
   473  
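        // insertConditionalJumpTrampoline inserts a trampoline (an unconditional branch to cbr's original target) at the
        // end of currentBlk and retargets cbr to it, so that the conditional branch only needs to reach the end of its
        // own block while the trampoline covers the remaining distance.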
   474  func (m *machine) insertConditionalJumpTrampoline(cbr *instruction, currentBlk *labelPosition, nextLabel label) {
   475  	cur := currentBlk.end
   476  	originalTarget := cbr.condBrLabel()
   477  	endNext := cur.next
   478  
   479  	if cur.kind != br {
   480  		// If the current block ended with an unconditional branch, we could simply place the trampoline after it.
   481  		// Since it does not, we insert a "skip" branch so that fall-through execution jumps over the trampoline.
   482  		skip := m.allocateInstr()
   483  		skip.asBr(nextLabel)
   484  		cur = linkInstr(cur, skip)
   485  	}
   486  
   487  	cbrNewTargetInstr, cbrNewTargetLabel := m.allocateBrTarget()
   488  	cbr.setCondBrTargets(cbrNewTargetLabel)
   489  	cur = linkInstr(cur, cbrNewTargetInstr)
   490  
   491  	// Then insert the unconditional branch to the original target. This should always be encodable,
   492  	// since a 26-bit offset is enough for any practical function.
   493  	br := m.allocateInstr()
   494  	br.asBr(originalTarget)
   495  	cur = linkInstr(cur, br)
   496  
   497  	// Update the end of the current block.
   498  	currentBlk.end = cur
   499  
   500  	linkInstr(cur, endNext)
   501  }
   502  
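        // getOrAllocateSSABlockLabel returns the label for the given ssa.BasicBlock, allocating one if it does not
        // exist yet. The return block maps to the special returnLabel.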
   503  func (m *machine) getOrAllocateSSABlockLabel(blk ssa.BasicBlock) label {
   504  	if blk.ReturnBlock() {
   505  		return returnLabel
   506  	}
   507  	l := m.ssaBlockIDToLabels[blk.ID()]
   508  	if l == invalidLabel {
   509  		l = m.allocateLabel()
   510  		m.ssaBlockIDToLabels[blk.ID()] = l
   511  	}
   512  	return l
   513  }
   514  
   515  // LinkAdjacentBlocks implements backend.Machine.
   516  func (m *machine) LinkAdjacentBlocks(prev, next ssa.BasicBlock) {
   517  	prevLabelPos := m.labelPositions[m.getOrAllocateSSABlockLabel(prev)]
   518  	nextLabelPos := m.labelPositions[m.getOrAllocateSSABlockLabel(next)]
   519  	prevLabelPos.end.next = nextLabelPos.begin
   520  }
   521  
   522  // Format implements backend.Machine.
   523  func (m *machine) Format() string {
   524  	begins := map[*instruction]label{}
   525  	for l, pos := range m.labelPositions {
   526  		begins[pos.begin] = l
   527  	}
   528  
   529  	irBlocks := map[label]ssa.BasicBlockID{}
   530  	for i, l := range m.ssaBlockIDToLabels {
   531  		irBlocks[l] = ssa.BasicBlockID(i)
   532  	}
   533  
   534  	var lines []string
   535  	for cur := m.rootInstr; cur != nil; cur = cur.next {
   536  		if l, ok := begins[cur]; ok {
   537  			var labelStr string
   538  			if blkID, ok := irBlocks[l]; ok {
   539  				labelStr = fmt.Sprintf("%s (SSA Block: %s):", l, blkID)
   540  			} else {
   541  				labelStr = fmt.Sprintf("%s:", l)
   542  			}
   543  			lines = append(lines, labelStr)
   544  		}
   545  		if cur.kind == nop0 {
   546  			continue
   547  		}
   548  		lines = append(lines, "\t"+cur.String())
   549  	}
   550  	return "\n" + strings.Join(lines, "\n") + "\n"
   551  }
   552  
   553  // InsertReturn implements backend.Machine.
   554  func (m *machine) InsertReturn() {
   555  	i := m.allocateInstr()
   556  	i.asRet(m.currentABI)
   557  	m.insert(i)
   558  }
   559  
   560  func (m *machine) getVRegSpillSlotOffsetFromSP(id regalloc.VRegID, size byte) int64 {
   561  	offset, ok := m.spillSlots[id]
   562  	if !ok {
   563  		offset = m.spillSlotSize
   564  		// TODO: this should be aligned depending on the `size` to use Imm12 offset load/store as much as possible.
   565  		m.spillSlots[id] = offset
   566  		m.spillSlotSize += int64(size)
   567  	}
   568  	return offset + 16 // spill slot starts above the clobbered registers and the frame size.
   569  }
   570  
   571  func (m *machine) clobberedRegSlotSize() int64 {
   572  	return int64(len(m.clobberedRegs) * 16)
   573  }
   574  
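        // arg0OffsetFromSP returns the offset from SP at which the first stack-passed argument lives.
        // For example (hypothetical numbers), with two clobbered registers (32 bytes) and a 16-byte spill region,
        // frameSize() is 48, so this returns 48 + 16 + 16 = 80.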
   575  func (m *machine) arg0OffsetFromSP() int64 {
   576  	return m.frameSize() +
   577  		16 + // 16-byte aligned return address
   578  		16 // frame size saved below the clobbered registers.
   579  }
   580  
   581  func (m *machine) ret0OffsetFromSP() int64 {
   582  	return m.arg0OffsetFromSP() + m.currentABI.argStackSize
   583  }
   584  
   585  func (m *machine) requiredStackSize() int64 {
   586  	return m.maxRequiredStackSizeForCalls +
   587  		m.frameSize() +
   588  		16 + // 16-byte aligned return address.
   589  		16 // frame size saved below the clobbered registers.
   590  }
   591  
   592  func (m *machine) frameSize() int64 {
   593  	s := m.clobberedRegSlotSize() + m.spillSlotSize
   594  	if s&0xf != 0 {
   595  		panic(fmt.Errorf("BUG: frame size %d is not 16-byte aligned", s))
   596  	}
   597  	return s
   598  }