github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/backend/isa/arm64/lower_mem.go

     1  package arm64
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc"
     7  	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
     8  	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
     9  )
    10  
    11  type (
    12  	// addressMode represents an ARM64 addressing mode.
    13  	//
    14  	// https://developer.arm.com/documentation/102374/0101/Loads-and-stores---addressing
    15  	// TODO: use the bit-packed layout like the operand struct.
    16  	addressMode struct {
    17  		kind   addressModeKind
    18  		rn, rm regalloc.VReg
    19  		extOp  extendOp
    20  		imm    int64
    21  	}
    22  
    23  	// addressModeKind represents the kind of ARM64 addressing mode.
    24  	addressModeKind byte
    25  )
    26  
    27  const (
    28  	// addressModeKindRegScaledExtended takes a base register and an index register. The index register is sign/zero-extended,
    29  	// and then scaled by bits(type)/8.
    30  	//
    31  	// e.g.
    32  	// 	- ldrh w1, [x2, w3, SXTW #1] ;; sign-extended and scaled by 2 (== LSL #1)
    33  	// 	- strh w1, [x2, w3, UXTW #1] ;; zero-extended and scaled by 2 (== LSL #1)
    34  	// 	- ldr w1, [x2, w3, SXTW #2] ;; sign-extended and scaled by 4 (== LSL #2)
    35  	// 	- str x1, [x2, w3, UXTW #3] ;; zero-extended and scaled by 8 (== LSL #3)
    36  	//
    37  	// See the following pages:
    38  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRH--register---Load-Register-Halfword--register--
    39  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDR--register---Load-Register--register--
    40  	addressModeKindRegScaledExtended addressModeKind = iota
    41  
    42  	// addressModeKindRegScaled is the same as addressModeKindRegScaledExtended, but without the extension: only the scaling is applied to the index register.
    43  	addressModeKindRegScaled
    44  
    45  	// addressModeKindRegExtended is the same as addressModeKindRegScaledExtended, but without the scaling: only the extension is applied to the index register.
    46  	addressModeKindRegExtended
    47  
    48  	// addressModeKindRegReg takes a base register and an index register. The index register is neither scaled nor extended.
    49  	addressModeKindRegReg
    50  
    51  	// addressModeKindRegSignedImm9 takes a base register and a 9-bit "signed" immediate offset (-256 to 255).
    52  	// The immediate is sign-extended and added to the base register.
    53  	// This is a.k.a. "unscaled" since the immediate is not scaled.
    54  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDUR--Load-Register--unscaled--
    55  	addressModeKindRegSignedImm9
    56  
    57  	// addressModeKindRegUnsignedImm12 takes a base register and a 12-bit "unsigned" immediate offset, scaled by
    58  	// the size of the type. In other words, the actual offset is imm12 * bits(type)/8.
    59  	// See "Unsigned offset" in the following pages:
    60  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRB--immediate---Load-Register-Byte--immediate--
    61  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRH--immediate---Load-Register-Halfword--immediate--
    62  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDR--immediate---Load-Register--immediate--
    63  	addressModeKindRegUnsignedImm12
    64  
    65  	// addressModeKindPostIndex takes a base register and a 9-bit "signed" immediate offset.
    66  	// After the load/store, the base register will be updated by the offset.
    67  	//
    68  	// Note that when this is used for a pair load/store, the offset is a 7-bit "signed" immediate offset.
    69  	//
    70  	// See "Post-index" in the following pages for examples:
    71  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRB--immediate---Load-Register-Byte--immediate--
    72  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRH--immediate---Load-Register-Halfword--immediate--
    73  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDR--immediate---Load-Register--immediate--
    74  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDP--Load-Pair-of-Registers-
    75  	addressModeKindPostIndex
    76  
    77  	// addressModeKindPreIndex takes a base register and a 9-bit "signed" immediate offset.
    78  	// Before the load/store, the base register will be updated by the offset.
    79  	//
    80  	// Note that when this is used for a pair load/store, the offset is a 7-bit "signed" immediate offset.
    81  	//
    82  	// See "Pre-index" in the following pages for examples:
    83  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRB--immediate---Load-Register-Byte--immediate--
    84  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDRH--immediate---Load-Register-Halfword--immediate--
    85  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDR--immediate---Load-Register--immediate--
    86  	// https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/LDP--Load-Pair-of-Registers-
    87  	addressModeKindPreIndex
    88  
    89  	// addressModeKindArgStackSpace is used to resolve the address of the argument stack space
    90  	// located right above the stack pointer. Since we don't know the exact stack space needed for a function
    91  	// at compile time, this is used as a placeholder and is further lowered to a real addressing mode like the ones above.
    92  	addressModeKindArgStackSpace
    93  
    94  	// addressModeKindResultStackSpace is used to resolve the address of the result stack space
    95  	// located right above the stack pointer. Since we don't know the exact stack space needed for a function
    96  	// at compile time, this is used as a placeholder and is further lowered to a real addressing mode like the ones above.
    97  	addressModeKindResultStackSpace
    98  )
    99  
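        // format returns the assembly-style string representation of the addressing mode.
        // dstSizeBits is the size in bits of the loaded/stored value; it only affects the
        // scaled modes, where it determines the shift (scale) amount.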
   100  func (a addressMode) format(dstSizeBits byte) (ret string) {
   101  	base := formatVRegSized(a.rn, 64)
   102  	if rn := a.rn; rn.RegType() != regalloc.RegTypeInt {
   103  		panic("invalid base register type: " + a.rn.RegType().String())
   104  	} else if rn.IsRealReg() && v0 <= a.rn.RealReg() && a.rn.RealReg() <= v30 {
   105  		panic("BUG: likely a bug in reg alloc or reset behavior")
   106  	}
   107  
   108  	switch a.kind {
   109  	case addressModeKindRegScaledExtended:
   110  		amount := a.sizeInBitsToShiftAmount(dstSizeBits)
   111  		ret = fmt.Sprintf("[%s, %s, %s #%#x]", base, formatVRegSized(a.rm, a.indexRegBits()), a.extOp, amount)
   112  	case addressModeKindRegScaled:
   113  		amount := a.sizeInBitsToShiftAmount(dstSizeBits)
   114  		ret = fmt.Sprintf("[%s, %s, lsl #%#x]", base, formatVRegSized(a.rm, a.indexRegBits()), amount)
   115  	case addressModeKindRegExtended:
   116  		ret = fmt.Sprintf("[%s, %s, %s]", base, formatVRegSized(a.rm, a.indexRegBits()), a.extOp)
   117  	case addressModeKindRegReg:
   118  		ret = fmt.Sprintf("[%s, %s]", base, formatVRegSized(a.rm, a.indexRegBits()))
   119  	case addressModeKindRegSignedImm9:
   120  		if a.imm != 0 {
   121  			ret = fmt.Sprintf("[%s, #%#x]", base, a.imm)
   122  		} else {
   123  			ret = fmt.Sprintf("[%s]", base)
   124  		}
   125  	case addressModeKindRegUnsignedImm12:
   126  		if a.imm != 0 {
   127  			ret = fmt.Sprintf("[%s, #%#x]", base, a.imm)
   128  		} else {
   129  			ret = fmt.Sprintf("[%s]", base)
   130  		}
   131  	case addressModeKindPostIndex:
   132  		ret = fmt.Sprintf("[%s], #%#x", base, a.imm)
   133  	case addressModeKindPreIndex:
   134  		ret = fmt.Sprintf("[%s, #%#x]!", base, a.imm)
   135  	case addressModeKindArgStackSpace:
   136  		ret = fmt.Sprintf("[#arg_space, #%#x]", a.imm)
   137  	case addressModeKindResultStackSpace:
   138  		ret = fmt.Sprintf("[#ret_space, #%#x]", a.imm)
   139  	}
   140  	return
   141  }
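        // Illustrative examples of the immediate-offset forms produced by format above, assuming
        // a base register that prints as x15 (the register name here is only a placeholder):
        //
        //	addressModeKindRegUnsignedImm12, imm=8   => "[x15, #0x8]"
        //	addressModeKindPreIndex, imm=16          => "[x15, #0x10]!"
        //	addressModeKindPostIndex, imm=-32        => "[x15], #-0x20"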
   142  
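        // addressModePreOrPostIndex returns an addressMode that adds imm to rn either before
        // (preIndex=true) or after (preIndex=false) the memory access, and panics if imm does
        // not fit in the signed 9-bit immediate range.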
   143  func addressModePreOrPostIndex(rn regalloc.VReg, imm int64, preIndex bool) addressMode {
   144  	if !offsetFitsInAddressModeKindRegSignedImm9(imm) {
   145  		panic(fmt.Sprintf("BUG: offset %#x does not fit in addressModeKindRegSignedImm9", imm))
   146  	}
   147  	if preIndex {
   148  		return addressMode{kind: addressModeKindPreIndex, rn: rn, imm: imm}
   149  	} else {
   150  		return addressMode{kind: addressModeKindPostIndex, rn: rn, imm: imm}
   151  	}
   152  }
   153  
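        // offsetFitsInAddressModeKindRegUnsignedImm12 reports whether offset can be encoded as a
        // scaled unsigned 12-bit immediate for an access of dstSizeInBits: it must be a positive
        // multiple of the access size in bytes, no larger than 4095 times that size
        // (e.g. 8, 16, ..., 32760 for a 64-bit access).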
   154  func offsetFitsInAddressModeKindRegUnsignedImm12(dstSizeInBits byte, offset int64) bool {
   155  	divisor := int64(dstSizeInBits) / 8
   156  	return 0 < offset && offset%divisor == 0 && offset/divisor < 4096
   157  }
   158  
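        // offsetFitsInAddressModeKindRegSignedImm9 reports whether offset fits in the unscaled
        // signed 9-bit immediate range [-256, 255].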
   159  func offsetFitsInAddressModeKindRegSignedImm9(offset int64) bool {
   160  	return -256 <= offset && offset <= 255
   161  }
   162  
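        // indexRegBits returns the width (32 or 64) of the index register rm, derived from the
        // source width of the extend operation attached to the addressing mode.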
   163  func (a addressMode) indexRegBits() byte {
   164  	bits := a.extOp.srcBits()
   165  	if bits != 32 && bits != 64 {
   166  		panic("invalid index register for address mode. it must be either 32 or 64 bits")
   167  	}
   168  	return bits
   169  }
   170  
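        // sizeInBitsToShiftAmount returns the scale (LSL amount) implied by the access size, i.e.
        // log2(sizeInBits/8): 8->0, 16->1, 32->2, 64->3. Other sizes fall through and return 0.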
   171  func (a addressMode) sizeInBitsToShiftAmount(sizeInBits byte) (lsl byte) {
   172  	switch sizeInBits {
   173  	case 8:
   174  		lsl = 0
   175  	case 16:
   176  		lsl = 1
   177  	case 32:
   178  		lsl = 2
   179  	case 64:
   180  		lsl = 3
   181  	}
   182  	return
   183  }
   184  
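        // extLoadSignSize maps an extending-load opcode to the width in bits of the loaded value
        // and whether the load sign-extends (true) or zero-extends (false).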
   185  func extLoadSignSize(op ssa.Opcode) (size byte, signed bool) {
   186  	switch op {
   187  	case ssa.OpcodeUload8:
   188  		size, signed = 8, false
   189  	case ssa.OpcodeUload16:
   190  		size, signed = 16, false
   191  	case ssa.OpcodeUload32:
   192  		size, signed = 32, false
   193  	case ssa.OpcodeSload8:
   194  		size, signed = 8, true
   195  	case ssa.OpcodeSload16:
   196  		size, signed = 16, true
   197  	case ssa.OpcodeSload32:
   198  		size, signed = 32, true
   199  	default:
   200  		panic("BUG")
   201  	}
   202  	return
   203  }
   204  
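        // lowerExtLoad lowers an extending load: it resolves ptr+offset into an addressMode and
        // emits a single sign- or zero-extending load of size bits into ret.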
   205  func (m *machine) lowerExtLoad(op ssa.Opcode, ptr ssa.Value, offset uint32, ret regalloc.VReg) {
   206  	size, signed := extLoadSignSize(op)
   207  	amode := m.lowerToAddressMode(ptr, offset, size)
   208  	load := m.allocateInstr()
   209  	if signed {
   210  		load.asSLoad(operandNR(ret), amode, size)
   211  	} else {
   212  		load.asULoad(operandNR(ret), amode, size)
   213  	}
   214  	m.insert(load)
   215  }
   216  
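        // lowerLoad lowers a plain load of typ from ptr+offset into the virtual register of ret,
        // using an integer load for i32/i64 and an FPU/vector load for f32/f64/v128.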
   217  func (m *machine) lowerLoad(ptr ssa.Value, offset uint32, typ ssa.Type, ret ssa.Value) {
   218  	amode := m.lowerToAddressMode(ptr, offset, typ.Bits())
   219  
   220  	dst := m.compiler.VRegOf(ret)
   221  	load := m.allocateInstr()
   222  	switch typ {
   223  	case ssa.TypeI32, ssa.TypeI64:
   224  		load.asULoad(operandNR(dst), amode, typ.Bits())
   225  	case ssa.TypeF32, ssa.TypeF64:
   226  		load.asFpuLoad(operandNR(dst), amode, typ.Bits())
   227  	case ssa.TypeV128:
   228  		load.asFpuLoad(operandNR(dst), amode, 128)
   229  	default:
   230  		panic("TODO")
   231  	}
   232  	m.insert(load)
   233  }
   234  
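        // lowerLoadSplat lowers a vector load-splat: it materializes ptr+offset into a single base
        // register and emits LD1R, which loads one lane-sized element and replicates it to all
        // lanes of the destination vector.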
   235  func (m *machine) lowerLoadSplat(ptr ssa.Value, offset uint32, lane ssa.VecLane, ret ssa.Value) {
   236  	// vecLoad1R supports an immediate offset only in its post-index form, so we simply add the offset to the base register beforehand.
   237  	base := m.getOperand_NR(m.compiler.ValueDefinition(ptr), extModeNone).nr()
   238  	offsetReg := m.compiler.AllocateVReg(ssa.TypeI64)
   239  	m.lowerConstantI64(offsetReg, int64(offset))
   240  	addedBase := m.addReg64ToReg64(base, offsetReg)
   241  
   242  	rd := operandNR(m.compiler.VRegOf(ret))
   243  
   244  	ld1r := m.allocateInstr()
   245  	ld1r.asVecLoad1R(rd, operandNR(addedBase), ssaLaneToArrangement(lane))
   246  	m.insert(ld1r)
   247  }
   248  
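        // lowerStore lowers an SSA store: it resolves the destination address into an addressMode
        // and emits a single store of the value operand.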
   249  func (m *machine) lowerStore(si *ssa.Instruction) {
   250  	// TODO: merge consecutive stores into a single pair store instruction.
   251  	value, ptr, offset, storeSizeInBits := si.StoreData()
   252  	amode := m.lowerToAddressMode(ptr, offset, storeSizeInBits)
   253  
   254  	valueOp := m.getOperand_NR(m.compiler.ValueDefinition(value), extModeNone)
   255  	store := m.allocateInstr()
   256  	store.asStore(valueOp, amode, storeSizeInBits)
   257  	m.insert(store)
   258  }
   259  
   260  // lowerToAddressMode converts a pointer to an addressMode that can be used as an operand for load/store instructions.
   261  func (m *machine) lowerToAddressMode(ptr ssa.Value, offsetBase uint32, size byte) (amode addressMode) {
   262  	// TODO: currently the instruction selection logic doesn't support addressModeKindRegScaledExtended and
   263  	// addressModeKindRegScaled since collectAddends doesn't take ssa.OpcodeIshl into account. This should be fixed
   264  	// to support more efficient address resolution.
   265  
   266  	a32s, a64s, offset := m.collectAddends(ptr)
   267  	offset += int64(offsetBase)
   268  	return m.lowerToAddressModeFromAddends(a32s, a64s, size, offset)
   269  }
   270  
   271  // lowerToAddressModeFromAddends creates an addressMode from a list of addends collected by collectAddends.
   272  // During the construction, this might emit additional instructions.
   273  //
   274  // Extracted as a separate function for easy testing.
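        //
        // For example (illustrative; the addends named here are placeholders): a single 64-bit
        // addend with offset 16 and a 64-bit access takes the unsigned-imm12 case, producing
        // addressModeKindRegUnsignedImm12 with imm=16 and no extra instructions; a 64-bit addend
        // combined with a 32-bit addend takes the register-extended case, producing
        // addressModeKindRegExtended, after which any remaining positive offset and any further
        // addends are folded into the base register with explicit add instructions.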
   275  func (m *machine) lowerToAddressModeFromAddends(a32s *wazevoapi.Queue[addend32], a64s *wazevoapi.Queue[regalloc.VReg], size byte, offset int64) (amode addressMode) {
   276  	switch a64sExist, a32sExist := !a64s.Empty(), !a32s.Empty(); {
   277  	case a64sExist && a32sExist:
   278  		var base regalloc.VReg
   279  		base = a64s.Dequeue()
   280  		var a32 addend32
   281  		a32 = a32s.Dequeue()
   282  		amode = addressMode{kind: addressModeKindRegExtended, rn: base, rm: a32.r, extOp: a32.ext}
   283  	case a64sExist && offsetFitsInAddressModeKindRegUnsignedImm12(size, offset):
   284  		var base regalloc.VReg
   285  		base = a64s.Dequeue()
   286  		amode = addressMode{kind: addressModeKindRegUnsignedImm12, rn: base, imm: offset}
   287  		offset = 0
   288  	case a64sExist && offsetFitsInAddressModeKindRegSignedImm9(offset):
   289  		var base regalloc.VReg
   290  		base = a64s.Dequeue()
   291  		amode = addressMode{kind: addressModeKindRegSignedImm9, rn: base, imm: offset}
   292  		offset = 0
   293  	case a64sExist:
   294  		var base regalloc.VReg
   295  		base = a64s.Dequeue()
   296  		if !a64s.Empty() {
   297  			index := a64s.Dequeue()
   298  			amode = addressMode{kind: addressModeKindRegReg, rn: base, rm: index, extOp: extendOpUXTX /* indicates index reg is 64-bit */}
   299  		} else {
   300  			amode = addressMode{kind: addressModeKindRegUnsignedImm12, rn: base, imm: 0}
   301  		}
   302  	case a32sExist:
   303  		base32 := a32s.Dequeue()
   304  
   305  		// First we need 64-bit base.
   306  		base := m.compiler.AllocateVReg(ssa.TypeI64)
   307  		baseExt := m.allocateInstr()
   308  		var signed bool
   309  		if base32.ext == extendOpSXTW {
   310  			signed = true
   311  		}
   312  		baseExt.asExtend(base, base32.r, 32, 64, signed)
   313  		m.insert(baseExt)
   314  
   315  		if !a32s.Empty() {
   316  			index := a32s.Dequeue()
   317  			amode = addressMode{kind: addressModeKindRegExtended, rn: base, rm: index.r, extOp: index.ext}
   318  		} else {
   319  			amode = addressMode{kind: addressModeKindRegUnsignedImm12, rn: base, imm: 0}
   320  		}
   321  	default: // Only static offsets.
   322  		tmpReg := m.compiler.AllocateVReg(ssa.TypeI64)
   323  		m.lowerConstantI64(tmpReg, offset)
   324  		amode = addressMode{kind: addressModeKindRegUnsignedImm12, rn: tmpReg, imm: 0}
   325  		offset = 0
   326  	}
   327  
   328  	baseReg := amode.rn
   329  	if offset > 0 {
   330  		baseReg = m.addConstToReg64(baseReg, offset) // baseReg += offset
   331  	}
   332  
   333  	for !a64s.Empty() {
   334  		a64 := a64s.Dequeue()
   335  		baseReg = m.addReg64ToReg64(baseReg, a64) // baseReg += a64
   336  	}
   337  
   338  	for !a32s.Empty() {
   339  		a32 := a32s.Dequeue()
   340  		baseReg = m.addRegToReg64Ext(baseReg, a32.r, a32.ext) // baseReg += (a32 extended to 64-bit)
   341  	}
   342  	amode.rn = baseReg
   343  	return
   344  }
   345  
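        // addendsMatchOpcodes is the set of opcodes that collectAddends knows how to fold into an
        // address computation; anything else is treated as an opaque 64-bit addend.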
   346  var addendsMatchOpcodes = [4]ssa.Opcode{ssa.OpcodeUExtend, ssa.OpcodeSExtend, ssa.OpcodeIadd, ssa.OpcodeIconst}
   347  
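        // collectAddends walks the SSA definition of ptr, decomposing chains of iadd, iconst,
        // uextend and sextend into 32-bit addends (with their extension), 64-bit addends, and a
        // statically-known offset. Matched instructions are marked as lowered.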
   348  func (m *machine) collectAddends(ptr ssa.Value) (addends32 *wazevoapi.Queue[addend32], addends64 *wazevoapi.Queue[regalloc.VReg], offset int64) {
   349  	m.addendsWorkQueue.Reset()
   350  	m.addends32.Reset()
   351  	m.addends64.Reset()
   352  	m.addendsWorkQueue.Enqueue(ptr)
   353  
   354  	for !m.addendsWorkQueue.Empty() {
   355  		v := m.addendsWorkQueue.Dequeue()
   356  
   357  		def := m.compiler.ValueDefinition(v)
   358  		switch op := m.compiler.MatchInstrOneOf(def, addendsMatchOpcodes[:]); op {
   359  		case ssa.OpcodeIadd:
   360  			// If the addend is an add, we recursively collect its operands.
   361  			x, y := def.Instr.Arg2()
   362  			m.addendsWorkQueue.Enqueue(x)
   363  			m.addendsWorkQueue.Enqueue(y)
   364  			def.Instr.MarkLowered()
   365  		case ssa.OpcodeIconst:
   366  			// If the addend is constant, we just statically merge it into the offset.
   367  			ic := def.Instr
   368  			u64 := ic.ConstantVal()
   369  			if ic.Return().Type().Bits() == 32 {
   370  				offset += int64(int32(u64)) // sign-extend.
   371  			} else {
   372  				offset += int64(u64)
   373  			}
   374  			def.Instr.MarkLowered()
   375  		case ssa.OpcodeUExtend, ssa.OpcodeSExtend:
   376  			input := def.Instr.Arg()
   377  			if input.Type().Bits() != 32 {
   378  				panic("illegal size: " + input.Type().String())
   379  			}
   380  
   381  			var ext extendOp
   382  			if op == ssa.OpcodeUExtend {
   383  				ext = extendOpUXTW
   384  			} else {
   385  				ext = extendOpSXTW
   386  			}
   387  
   388  			inputDef := m.compiler.ValueDefinition(input)
   389  			constInst := inputDef.IsFromInstr() && inputDef.Instr.Constant()
   390  			switch {
   391  			case constInst && ext == extendOpUXTW:
   392  				// Zero-extension of a 32-bit constant can be merged into the offset.
   393  				offset += int64(uint32(inputDef.Instr.ConstantVal()))
   394  			case constInst && ext == extendOpSXTW:
   395  				// Sign-extension of a 32-bit constant can be merged into the offset.
   396  				offset += int64(int32(inputDef.Instr.ConstantVal())) // sign-extend!
   397  			default:
   398  				m.addends32.Enqueue(addend32{r: m.getOperand_NR(inputDef, extModeNone).nr(), ext: ext})
   399  			}
   400  			def.Instr.MarkLowered()
   401  			continue
   402  		default:
   403  			// If the addend is not one of them, we simply use it as-is (without merging!), optionally zero-extending it.
   404  			m.addends64.Enqueue(m.getOperand_NR(def, extModeZeroExtend64 /* optional zero ext */).nr())
   405  		}
   406  	}
   407  	return &m.addends32, &m.addends64, offset
   408  }
   409  
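        // addConstToReg64 emits rd = r + c into a freshly allocated 64-bit register and returns it,
        // using an add (or a sub of -c) with a 12-bit immediate when possible, and otherwise loading
        // the constant into a temporary register first.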
   410  func (m *machine) addConstToReg64(r regalloc.VReg, c int64) (rd regalloc.VReg) {
   411  	rd = m.compiler.AllocateVReg(ssa.TypeI64)
   412  	alu := m.allocateInstr()
   413  	if imm12Op, ok := asImm12Operand(uint64(c)); ok {
   414  		alu.asALU(aluOpAdd, operandNR(rd), operandNR(r), imm12Op, true)
   415  	} else if imm12Op, ok = asImm12Operand(uint64(-c)); ok {
   416  		alu.asALU(aluOpSub, operandNR(rd), operandNR(r), imm12Op, true)
   417  	} else {
   418  		tmp := m.compiler.AllocateVReg(ssa.TypeI64)
   419  		m.load64bitConst(c, tmp)
   420  		alu.asALU(aluOpAdd, operandNR(rd), operandNR(r), operandNR(tmp), true)
   421  	}
   422  	m.insert(alu)
   423  	return
   424  }
   425  
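        // addReg64ToReg64 emits rd = rn + rm on 64-bit registers into a freshly allocated register
        // and returns it.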
   426  func (m *machine) addReg64ToReg64(rn, rm regalloc.VReg) (rd regalloc.VReg) {
   427  	rd = m.compiler.AllocateVReg(ssa.TypeI64)
   428  	alu := m.allocateInstr()
   429  	alu.asALU(aluOpAdd, operandNR(rd), operandNR(rn), operandNR(rm), true)
   430  	m.insert(alu)
   431  	return
   432  }
   433  
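        // addRegToReg64Ext emits rd = rn + extend(rm), extending rm to 64 bits per ext, into a
        // freshly allocated register and returns it.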
   434  func (m *machine) addRegToReg64Ext(rn, rm regalloc.VReg, ext extendOp) (rd regalloc.VReg) {
   435  	rd = m.compiler.AllocateVReg(ssa.TypeI64)
   436  	alu := m.allocateInstr()
   437  	alu.asALU(aluOpAdd, operandNR(rd), operandNR(rn), operandER(rm, ext, 64), true)
   438  	m.insert(alu)
   439  	return
   440  }