github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/backend/isa/arm64/lower_instr_operands.go (about)

     1  package arm64
     2  
     3  // This file contains the logic to "find and determine operands" for instructions.
     4  // In order to finalize the form of an operand, we might end up merging/eliminating
     5  // the source instructions into an operand whenever possible.
     6  
     7  import (
     8  	"fmt"
     9  
    10  	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend"
    11  	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc"
    12  	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
    13  )
    14  
    15  type (
    16  	// operand represents an operand of an instruction whose type is determined by the kind.
    17  	operand struct {
    18  		kind        operandKind
    19  		data, data2 uint64
    20  	}
    21  	operandKind byte
    22  )
    23  
    24  // Here's the list of operand kinds. We use the abbreviation of the kind name not only for these consts,
    25  // but also names of functions which return the operand of the kind.
    26  const (
    27  	// operandKindNR represents "NormalRegister" (NR). This is literally the register without any special operation unlike others.
    28  	operandKindNR operandKind = iota
    29  	// operandKindSR represents "Shifted Register" (SR). This is a register which is shifted by a constant.
    30  	// Some of the arm64 instructions can take this kind of operand.
    31  	operandKindSR
    32  	// operandKindER represents "Extended Register (ER). This is a register which is sign/zero-extended to a larger size.
    33  	// Some of the arm64 instructions can take this kind of operand.
    34  	operandKindER
    35  	// operandKindImm12 represents "Immediate 12" (Imm12). This is a 12-bit immediate value which can be either shifted by 12 or not.
    36  	// See asImm12 function for detail.
    37  	operandKindImm12
    38  	// operandKindShiftImm represents "Shifted Immediate" (ShiftImm) used by shift operations.
    39  	operandKindShiftImm
    40  )
    41  
    42  // String implements fmt.Stringer for debugging.
    43  func (o operand) format(size byte) string {
    44  	switch o.kind {
    45  	case operandKindNR:
    46  		return formatVRegSized(o.nr(), size)
    47  	case operandKindSR:
    48  		r, amt, sop := o.sr()
    49  		return fmt.Sprintf("%s, %s #%d", formatVRegSized(r, size), sop, amt)
    50  	case operandKindER:
    51  		r, eop, _ := o.er()
    52  		return fmt.Sprintf("%s %s", formatVRegSized(r, size), eop)
    53  	case operandKindImm12:
    54  		imm12, shiftBit := o.imm12()
    55  		if shiftBit == 1 {
    56  			return fmt.Sprintf("#%#x", uint64(imm12)<<12)
    57  		} else {
    58  			return fmt.Sprintf("#%#x", imm12)
    59  		}
    60  	default:
    61  		panic(fmt.Sprintf("unknown operand kind: %d", o.kind))
    62  	}
    63  }
    64  
    65  // operandNR encodes the given VReg as an operand of operandKindNR.
    66  func operandNR(r regalloc.VReg) operand {
    67  	return operand{kind: operandKindNR, data: uint64(r)}
    68  }
    69  
    70  // nr decodes the underlying VReg assuming the operand is of operandKindNR.
    71  func (o operand) nr() regalloc.VReg {
    72  	return regalloc.VReg(o.data)
    73  }
    74  
    75  // operandER encodes the given VReg as an operand of operandKindER.
    76  func operandER(r regalloc.VReg, eop extendOp, to byte) operand {
    77  	if to < 32 {
    78  		panic("TODO?BUG?: when we need to extend to less than 32 bits?")
    79  	}
    80  	return operand{kind: operandKindER, data: uint64(r), data2: uint64(eop)<<32 | uint64(to)}
    81  }
    82  
    83  // er decodes the underlying VReg, extend operation, and the target size assuming the operand is of operandKindER.
    84  func (o operand) er() (r regalloc.VReg, eop extendOp, to byte) {
    85  	return regalloc.VReg(o.data), extendOp(o.data2>>32) & 0xff, byte(o.data2 & 0xff)
    86  }
    87  
    88  // operandSR encodes the given VReg as an operand of operandKindSR.
    89  func operandSR(r regalloc.VReg, amt byte, sop shiftOp) operand {
    90  	return operand{kind: operandKindSR, data: uint64(r), data2: uint64(amt)<<32 | uint64(sop)}
    91  }
    92  
    93  // sr decodes the underlying VReg, shift amount, and shift operation assuming the operand is of operandKindSR.
    94  func (o operand) sr() (r regalloc.VReg, amt byte, sop shiftOp) {
    95  	return regalloc.VReg(o.data), byte(o.data2>>32) & 0xff, shiftOp(o.data2) & 0xff
    96  }
    97  
    98  // operandImm12 encodes the given imm12 as an operand of operandKindImm12.
    99  func operandImm12(imm12 uint16, shiftBit byte) operand {
   100  	return operand{kind: operandKindImm12, data: uint64(imm12) | uint64(shiftBit)<<32}
   101  }
   102  
   103  // imm12 decodes the underlying imm12 data assuming the operand is of operandKindImm12.
   104  func (o operand) imm12() (v uint16, shiftBit byte) {
   105  	return uint16(o.data), byte(o.data >> 32)
   106  }
   107  
   108  // operandShiftImm encodes the given amount as an operand of operandKindShiftImm.
   109  func operandShiftImm(amount byte) operand {
   110  	return operand{kind: operandKindShiftImm, data: uint64(amount)}
   111  }
   112  
   113  // shiftImm decodes the underlying shift amount data assuming the operand is of operandKindShiftImm.
   114  func (o operand) shiftImm() byte {
   115  	return byte(o.data)
   116  }
   117  
   118  // reg returns the register of the operand if applicable.
   119  func (o operand) reg() regalloc.VReg {
   120  	switch o.kind {
   121  	case operandKindNR:
   122  		return o.nr()
   123  	case operandKindSR:
   124  		r, _, _ := o.sr()
   125  		return r
   126  	case operandKindER:
   127  		r, _, _ := o.er()
   128  		return r
   129  	case operandKindImm12:
   130  		// Does not have a register.
   131  	case operandKindShiftImm:
   132  		// Does not have a register.
   133  	default:
   134  		panic(o.kind)
   135  	}
   136  	return regalloc.VRegInvalid
   137  }
   138  
   139  func (o operand) realReg() regalloc.RealReg {
   140  	return o.nr().RealReg()
   141  }
   142  
   143  func (o operand) assignReg(v regalloc.VReg) operand {
   144  	switch o.kind {
   145  	case operandKindNR:
   146  		return operandNR(v)
   147  	case operandKindSR:
   148  		_, amt, sop := o.sr()
   149  		return operandSR(v, amt, sop)
   150  	case operandKindER:
   151  		_, eop, to := o.er()
   152  		return operandER(v, eop, to)
   153  	case operandKindImm12:
   154  		// Does not have a register.
   155  	case operandKindShiftImm:
   156  		// Does not have a register.
   157  	}
   158  	panic(o.kind)
   159  }
   160  
   161  // ensureValueNR returns an operand of either operandKindER, operandKindSR, or operandKindNR from the given value (defined by `def).
   162  //
   163  // `mode` is used to extend the operand if the bit length is smaller than mode.bits().
   164  // If the operand can be expressed as operandKindImm12, `mode` is ignored.
   165  func (m *machine) getOperand_Imm12_ER_SR_NR(def *backend.SSAValueDefinition, mode extMode) (op operand) {
   166  	if def.IsFromBlockParam() {
   167  		return operandNR(def.BlkParamVReg)
   168  	}
   169  
   170  	instr := def.Instr
   171  	if instr.Opcode() == ssa.OpcodeIconst {
   172  		if imm12Op, ok := asImm12Operand(instr.ConstantVal()); ok {
   173  			instr.MarkLowered()
   174  			return imm12Op
   175  		}
   176  	}
   177  	return m.getOperand_ER_SR_NR(def, mode)
   178  }
   179  
   180  // getOperand_MaybeNegatedImm12_ER_SR_NR is almost the same as getOperand_Imm12_ER_SR_NR, but this might negate the immediate value.
   181  // If the immediate value is negated, the second return value is true, otherwise always false.
   182  func (m *machine) getOperand_MaybeNegatedImm12_ER_SR_NR(def *backend.SSAValueDefinition, mode extMode) (op operand, negatedImm12 bool) {
   183  	if def.IsFromBlockParam() {
   184  		return operandNR(def.BlkParamVReg), false
   185  	}
   186  
   187  	instr := def.Instr
   188  	if instr.Opcode() == ssa.OpcodeIconst {
   189  		c := instr.ConstantVal()
   190  		if imm12Op, ok := asImm12Operand(c); ok {
   191  			instr.MarkLowered()
   192  			return imm12Op, false
   193  		}
   194  
   195  		signExtended := int64(c)
   196  		if def.SSAValue().Type().Bits() == 32 {
   197  			signExtended = (signExtended << 32) >> 32
   198  		}
   199  		negatedWithoutSign := -signExtended
   200  		if imm12Op, ok := asImm12Operand(uint64(negatedWithoutSign)); ok {
   201  			instr.MarkLowered()
   202  			return imm12Op, true
   203  		}
   204  	}
   205  	return m.getOperand_ER_SR_NR(def, mode), false
   206  }
   207  
// getOperand_ER_SR_NR returns an operand of either operandKindER, operandKindSR, or operandKindNR for the
// value defined by `def`.
//
// `mode` is used to extend the operand if the bit length is smaller than mode.bits().
//
// When `def` is an SExtend/UExtend instruction accepted by MatchInstr, the extension is expressed as an
// extended-register (ER) operand, and the extend instruction is merged (marked lowered) where possible.
// Anything else is delegated to getOperand_SR_NR.
//
// NOTE(review): block parameters are returned as a plain NR operand without applying `mode` (getOperand_NR
// would extend them) — confirm callers never rely on the extension here.
func (m *machine) getOperand_ER_SR_NR(def *backend.SSAValueDefinition, mode extMode) (op operand) {
	if def.IsFromBlockParam() {
		return operandNR(def.BlkParamVReg)
	}

	if m.compiler.MatchInstr(def, ssa.OpcodeSExtend) || m.compiler.MatchInstr(def, ssa.OpcodeUExtend) {
		extInstr := def.Instr

		signed := extInstr.Opcode() == ssa.OpcodeSExtend
		innerExtFromBits, innerExtToBits := extInstr.ExtendFromToBits()
		modeBits, modeSigned := mode.bits(), mode.signed()
		// Either no outer extension is requested, or the inner extension already produces the requested
		// width: the inner extension alone becomes the ER operand.
		if mode == extModeNone || innerExtToBits == modeBits {
			eop := extendOpFrom(signed, innerExtFromBits)
			extArg := m.getOperand_NR(m.compiler.ValueDefinition(extInstr.Arg()), extModeNone)
			op = operandER(extArg.nr(), eop, innerExtToBits)
			extInstr.MarkLowered()
			return
		}

		if innerExtToBits > modeBits {
			panic("BUG?TODO?: need the results of inner extension to be larger than the mode")
		}

		// From here on innerExtToBits < modeBits, so an outer extension up to modeBits is still required.
		switch {
		case (!signed && !modeSigned) || (signed && modeSigned):
			// Two sign/zero extensions are equivalent to one sign/zero extension for the larger size.
			// NOTE(review): this reads the argument's VReg via VRegOf directly, whereas the branch above goes
			// through getOperand_NR (which inlines constants) — confirm a constant argument cannot reach here.
			eop := extendOpFrom(modeSigned, innerExtFromBits)
			op = operandER(m.compiler.VRegOf(extInstr.Arg()), eop, modeBits)
			extInstr.MarkLowered()
		case (signed && !modeSigned) || (!signed && modeSigned):
			// We need to {sign, zero}-extend the result of the {zero,sign} extension.
			eop := extendOpFrom(modeSigned, innerExtToBits)
			op = operandER(m.compiler.VRegOf(extInstr.Return()), eop, modeBits)
			// Note that we failed to merge the inner extension instruction in this case: it is not marked
			// lowered, and its result register is used as the ER source.
		}
		return
	}
	return m.getOperand_SR_NR(def, mode)
}
   250  
// getOperand_SR_NR returns an operand of either operandKindSR or operandKindNR for the value defined by `def`.
//
// `mode` is used to extend the operand if the bit length is smaller than mode.bits(). It is not applied when
// the value is folded into an SR operand, or when the value comes from a block parameter.
//
// If `def` is an Ishl accepted by MatchInstr whose shift amount is defined by a constant instruction, the
// shift is folded into an LSL shifted-register operand. Both the Ishl and the constant are then marked lowered.
func (m *machine) getOperand_SR_NR(def *backend.SSAValueDefinition, mode extMode) (op operand) {
	if def.IsFromBlockParam() {
		return operandNR(def.BlkParamVReg)
	}

	if m.compiler.MatchInstr(def, ssa.OpcodeIshl) {
		// Check if the shift amount is constant instruction.
		targetVal, amountVal := def.Instr.Arg2()
		// NOTE(review): the target operand is obtained before we know whether the fold succeeds. If the amount
		// is not a constant, whatever getOperand_NR emitted for the target (e.g. an inlined constant) looks
		// unused on this path — confirm this is intended.
		targetVReg := m.getOperand_NR(m.compiler.ValueDefinition(targetVal), extModeNone).nr()
		amountDef := m.compiler.ValueDefinition(amountVal)
		if amountDef.IsFromInstr() && amountDef.Instr.Constant() {
			// If that is the case, we can use the shifted register operand (SR).
			c := byte(amountDef.Instr.ConstantVal()) & (targetVal.Type().Bits() - 1) // Clears the unnecessary bits.
			def.Instr.MarkLowered()
			amountDef.Instr.MarkLowered()
			return operandSR(targetVReg, c, shiftOpLSL)
		}
	}
	return m.getOperand_NR(def, mode)
}
   274  
   275  // getOperand_ShiftImm_NR returns an operand of either operandKindShiftImm or operandKindNR from the given value (defined by `def).
   276  func (m *machine) getOperand_ShiftImm_NR(def *backend.SSAValueDefinition, mode extMode, shiftBitWidth byte) (op operand) {
   277  	if def.IsFromBlockParam() {
   278  		return operandNR(def.BlkParamVReg)
   279  	}
   280  
   281  	instr := def.Instr
   282  	if instr.Constant() {
   283  		amount := byte(instr.ConstantVal()) & (shiftBitWidth - 1) // Clears the unnecessary bits.
   284  		return operandShiftImm(amount)
   285  	}
   286  	return m.getOperand_NR(def, mode)
   287  }
   288  
   289  // ensureValueNR returns an operand of operandKindNR from the given value (defined by `def).
   290  //
   291  // `mode` is used to extend the operand if the bit length is smaller than mode.bits().
   292  func (m *machine) getOperand_NR(def *backend.SSAValueDefinition, mode extMode) (op operand) {
   293  	var v regalloc.VReg
   294  	if def.IsFromBlockParam() {
   295  		v = def.BlkParamVReg
   296  	} else {
   297  		instr := def.Instr
   298  		if instr.Constant() {
   299  			// We inline all the constant instructions so that we could reduce the register usage.
   300  			v = m.lowerConstant(instr)
   301  			instr.MarkLowered()
   302  		} else {
   303  			if n := def.N; n == 0 {
   304  				v = m.compiler.VRegOf(instr.Return())
   305  			} else {
   306  				_, rs := instr.Returns()
   307  				v = m.compiler.VRegOf(rs[n-1])
   308  			}
   309  		}
   310  	}
   311  
   312  	r := v
   313  	switch inBits := def.SSAValue().Type().Bits(); {
   314  	case mode == extModeNone:
   315  	case inBits == 32 && (mode == extModeZeroExtend32 || mode == extModeSignExtend32):
   316  	case inBits == 32 && mode == extModeZeroExtend64:
   317  		extended := m.compiler.AllocateVReg(ssa.TypeI64)
   318  		ext := m.allocateInstr()
   319  		ext.asExtend(extended, v, 32, 64, false)
   320  		m.insert(ext)
   321  		r = extended
   322  	case inBits == 32 && mode == extModeSignExtend64:
   323  		extended := m.compiler.AllocateVReg(ssa.TypeI64)
   324  		ext := m.allocateInstr()
   325  		ext.asExtend(extended, v, 32, 64, true)
   326  		m.insert(ext)
   327  		r = extended
   328  	case inBits == 64 && (mode == extModeZeroExtend64 || mode == extModeSignExtend64):
   329  	}
   330  	return operandNR(r)
   331  }
   332  
   333  func asImm12Operand(val uint64) (op operand, ok bool) {
   334  	v, shiftBit, ok := asImm12(val)
   335  	if !ok {
   336  		return operand{}, false
   337  	}
   338  	return operandImm12(v, shiftBit), true
   339  }
   340  
   341  func asImm12(val uint64) (v uint16, shiftBit byte, ok bool) {
   342  	const mask1, mask2 uint64 = 0xfff, 0xfff_000
   343  	if val&^mask1 == 0 {
   344  		return uint16(val), 0, true
   345  	} else if val&^mask2 == 0 {
   346  		return uint16(val >> 12), 1, true
   347  	} else {
   348  		return 0, 0, false
   349  	}
   350  }