github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/backend/isa/arm64/abi_go_call.go

package arm64

import (
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/backend"
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/backend/regalloc"
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/ssa"
	"github.com/bananabytelabs/wazero/internal/engine/wazevo/wazevoapi"
)

var calleeSavedRegistersSorted = []regalloc.VReg{
	x19VReg, x20VReg, x21VReg, x22VReg, x23VReg, x24VReg, x25VReg, x26VReg, x28VReg,
	v18VReg, v19VReg, v20VReg, v21VReg, v22VReg, v23VReg, v24VReg, v25VReg, v26VReg, v27VReg, v28VReg, v29VReg, v30VReg, v31VReg,
}

// CompileGoFunctionTrampoline implements backend.Machine.
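// It compiles a trampoline that bridges a Wasm-level call to a host Go function:
// the generated code saves the callee-saved registers into the execution context,
// copies the Wasm-level arguments into a stack region that the Go side reads as a
// []uint64, exits the execution with the given exitCode, and, after control returns,
// copies the results back and restores the saved registers.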
func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig *ssa.Signature, needModuleContextPtr bool) []byte {
	exct := m.executableContext
	argBegin := 1 // Skips exec context by default.
	if needModuleContextPtr {
		argBegin++
	}

	abi := &abiImpl{m: m}
	abi.init(sig)
	m.currentABI = abi

	cur := m.allocateInstr()
	cur.asNop0()
	exct.RootInstr = cur

	// Execution context is always the first argument.
	execCtrPtr := x0VReg

	// In the following, we create this stack layout:
	//
	//                   (high address)
	//     SP ------> +-----------------+  <----+
	//                |     .......     |       |
	//                |      ret Y      |       |
	//                |     .......     |       |
	//                |      ret 0      |       |
	//                |      arg X      |       |  size_of_arg_ret
	//                |     .......     |       |
	//                |      arg 1      |       |
	//                |      arg 0      |  <----+ <-------- originalArg0Reg
	//                | size_of_arg_ret |
	//                |  ReturnAddress  |
	//                +-----------------+ <----+
	//                |      xxxx       |      |  ;; might be padded to make it 16-byte aligned.
	//           +--->|  arg[N]/ret[M]  |      |
	//  sliceSize|    |   ............  |      | goCallStackSize
	//           |    |  arg[1]/ret[1]  |      |
	//           +--->|  arg[0]/ret[0]  | <----+ <-------- arg0ret0AddrReg
	//                |    sliceSize    |
	//                |   frame_size    |
	//                +-----------------+
	//                   (low address)
	//
	// where the "arg[0]/ret[0] ... arg[N]/ret[M]" region is the stack used by the Go function,
	// and is accessed on the Go side as a plain []uint64. That region is where we pass the
	// arguments and receive the return values.
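	//
	// For illustration, take a hypothetical signature (exec_ctx, i32, f64) -> v128 with
	// argBegin=1 (sizes follow goFunctionCallRequiredStackSize below): the Wasm-level
	// arguments need 8+8 = 16 bytes, the result needs 16 bytes, so goCallStackSize = 16
	// and the Go side sees a []uint64 of length sliceSize = 16/8 = 2.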

	// First, update the SP to create the "ReturnAddress + size_of_arg_ret" slots.
	cur = m.createReturnAddrAndSizeOfArgRetSlot(cur)

	const frameInfoSize = 16 // == frame_size + sliceSize.

	// Next, allocate the stack for the Go function call if necessary.
	goCallStackSize, sliceSizeInBytes := goFunctionCallRequiredStackSize(sig, argBegin)
	cur = m.insertStackBoundsCheck(goCallStackSize+frameInfoSize, cur)

	originalArg0Reg := x17VReg // Caller-saved, so we can use it freely.
	if m.currentABI.alignedArgResultStackSlotSize() > 0 {
		// At this point, SP points to `ReturnAddress`, so adding 16 gives the original arg 0 slot.
		cur = m.addsAddOrSubStackPointer(cur, originalArg0Reg, frameInfoSize, true)
	}

	// Save the callee-saved registers.
	cur = m.saveRegistersInExecutionContext(cur, calleeSavedRegistersSorted)

	// Next, we need to store all the arguments to the stack in the typical Wasm stack style.
	if needModuleContextPtr {
		offset := wazevoapi.ExecutionContextOffsetGoFunctionCallCalleeModuleContextOpaque.I64()
		if !offsetFitsInAddressModeKindRegUnsignedImm12(64, offset) {
			panic("BUG: too large or unaligned offset for goFunctionCallCalleeModuleContextOpaque in execution context")
		}

		// Module context is always the second argument.
		moduleCtrPtr := x1VReg
		store := m.allocateInstr()
		amode := addressMode{kind: addressModeKindRegUnsignedImm12, rn: execCtrPtr, imm: offset}
		store.asStore(operandNR(moduleCtrPtr), amode, 64)
		cur = linkInstr(cur, store)
	}

	// Advance (decrement) the stack pointer to allocate the stack for the Go function call.
	cur = m.addsAddOrSubStackPointer(cur, spVReg, goCallStackSize, false)

	// Copy the resulting SP, which now points at arg[0]/ret[0], to x15VReg.
	arg0ret0AddrReg := x15VReg // Caller-saved, so we can use it freely.
	copySp := m.allocateInstr()
	copySp.asMove64(arg0ret0AddrReg, spVReg)
	cur = linkInstr(cur, copySp)

	for i := range abi.args[argBegin:] {
		arg := &abi.args[argBegin+i]
		store := m.allocateInstr()
		var v regalloc.VReg
		if arg.Kind == backend.ABIArgKindReg {
			v = arg.Reg
		} else {
			cur, v = m.goFunctionCallLoadStackArg(cur, originalArg0Reg, arg,
				// Caller-saved, so we can use them freely.
				x11VReg, v11VReg)
		}

		var sizeInBits byte
		if arg.Type == ssa.TypeV128 {
			sizeInBits = 128
		} else {
			sizeInBits = 64
		}
		store.asStore(operandNR(v),
			addressMode{
				kind: addressModeKindPostIndex,
				rn:   arg0ret0AddrReg, imm: int64(sizeInBits / 8),
			}, sizeInBits)
		cur = linkInstr(cur, store)
	}

	// Finally, now that we've advanced SP to arg[0]/ret[0], we allocate `frame_size + sliceSize`.
	var frameSizeReg, sliceSizeReg regalloc.VReg
	if goCallStackSize > 0 {
		cur = m.lowerConstantI64AndInsert(cur, tmpRegVReg, goCallStackSize)
		frameSizeReg = tmpRegVReg
		cur = m.lowerConstantI64AndInsert(cur, x16VReg, sliceSizeInBytes/8)
		sliceSizeReg = x16VReg
	} else {
		frameSizeReg = xzrVReg
		sliceSizeReg = xzrVReg
	}
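	// Pre-indexing with -16 decrements SP before the pair store (assuming the third
	// parameter of addressModePreOrPostIndex selects pre-indexing when true), leaving
	// SP pointing at `frame_size` as in the diagram above.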
	_amode := addressModePreOrPostIndex(spVReg, -16, true)
	storeP := m.allocateInstr()
	storeP.asStorePair64(frameSizeReg, sliceSizeReg, _amode)
	cur = linkInstr(cur, storeP)

	// Set the exit status on the execution context.
	cur = m.setExitCode(cur, x0VReg, exitCode)

	// Save the current stack pointer.
	cur = m.saveCurrentStackPointer(cur, x0VReg)

	// Exit the execution.
	cur = m.storeReturnAddressAndExit(cur)

	// After the call, we need to restore the callee-saved registers.
	cur = m.restoreRegistersInExecutionContext(cur, calleeSavedRegistersSorted)

	// Get the pointer to arg[0]/ret[0]: we need to skip `frame_size + sliceSize`.
	if len(abi.rets) > 0 {
		cur = m.addsAddOrSubStackPointer(cur, arg0ret0AddrReg, frameInfoSize, true)
	}

	// Advance the SP so that it points to `ReturnAddress`.
	cur = m.addsAddOrSubStackPointer(cur, spVReg, frameInfoSize+goCallStackSize, true)
	ldr := m.allocateInstr()
	// Load the return address.
	ldr.asULoad(operandNR(lrVReg),
		addressModePreOrPostIndex(spVReg, 16 /* stack pointer must be 16-byte aligned. */, false /* increment after load */), 64)
	cur = linkInstr(cur, ldr)

	originalRet0Reg := x17VReg // Caller-saved, so we can use it freely.
	if m.currentABI.retStackSize > 0 {
		cur = m.addsAddOrSubStackPointer(cur, originalRet0Reg, m.currentABI.argStackSize, true)
	}

	// Make the SP point to the original address (above the result slot).
	if s := m.currentABI.alignedArgResultStackSlotSize(); s > 0 {
		cur = m.addsAddOrSubStackPointer(cur, spVReg, s, true)
	}

	for i := range abi.rets {
		r := &abi.rets[i]
		if r.Kind == backend.ABIArgKindReg {
			loadIntoReg := m.allocateInstr()
			mode := addressMode{kind: addressModeKindPostIndex, rn: arg0ret0AddrReg}
			switch r.Type {
			case ssa.TypeI32:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoReg.asULoad(operandNR(r.Reg), mode, 32)
			case ssa.TypeI64:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoReg.asULoad(operandNR(r.Reg), mode, 64)
			case ssa.TypeF32:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoReg.asFpuLoad(operandNR(r.Reg), mode, 32)
			case ssa.TypeF64:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoReg.asFpuLoad(operandNR(r.Reg), mode, 64)
			case ssa.TypeV128:
				mode.imm = 16
				loadIntoReg.asFpuLoad(operandNR(r.Reg), mode, 128)
			default:
				panic("TODO")
			}
			cur = linkInstr(cur, loadIntoReg)
		} else {
			// First, load the value into a temporary register, just like above.
			intTmp, floatTmp := x11VReg, v11VReg
			loadIntoTmpReg := m.allocateInstr()
			mode := addressMode{kind: addressModeKindPostIndex, rn: arg0ret0AddrReg}
			var resultReg regalloc.VReg
			switch r.Type {
			case ssa.TypeI32:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoTmpReg.asULoad(operandNR(intTmp), mode, 32)
				resultReg = intTmp
			case ssa.TypeI64:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoTmpReg.asULoad(operandNR(intTmp), mode, 64)
				resultReg = intTmp
			case ssa.TypeF32:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoTmpReg.asFpuLoad(operandNR(floatTmp), mode, 32)
				resultReg = floatTmp
			case ssa.TypeF64:
				mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
				loadIntoTmpReg.asFpuLoad(operandNR(floatTmp), mode, 64)
				resultReg = floatTmp
			case ssa.TypeV128:
				mode.imm = 16
				loadIntoTmpReg.asFpuLoad(operandNR(floatTmp), mode, 128)
				resultReg = floatTmp
			default:
				panic("TODO")
			}
			cur = linkInstr(cur, loadIntoTmpReg)
			cur = m.goFunctionCallStoreStackResult(cur, originalRet0Reg, r, resultReg)
		}
	}

	ret := m.allocateInstr()
	ret.asRet(nil)
	linkInstr(cur, ret)

	m.encode(m.executableContext.RootInstr)
	return m.compiler.Buf()
}

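// saveRegistersInExecutionContext stores each register in regs into the saved-registers
// area of the execution context (always passed in x0), spacing the slots 16 bytes apart
// so that 128-bit vector registers fit and stay aligned.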
func (m *machine) saveRegistersInExecutionContext(cur *instruction, regs []regalloc.VReg) *instruction {
	offset := wazevoapi.ExecutionContextOffsetSavedRegistersBegin.I64()
	for _, v := range regs {
		store := m.allocateInstr()
		var sizeInBits byte
		switch v.RegType() {
		case regalloc.RegTypeInt:
			sizeInBits = 64
		case regalloc.RegTypeFloat:
			sizeInBits = 128
		}
		store.asStore(operandNR(v),
			addressMode{
				kind: addressModeKindRegUnsignedImm12,
				// Execution context is always the first argument.
				rn: x0VReg, imm: offset,
			}, sizeInBits)
		store.prev = cur
		cur.next = store
		cur = store
		offset += 16 // The imm12 offset must be 16-byte aligned for vector registers, so we unconditionally store each register at a multiple of 16.
	}
	return cur
}

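// restoreRegistersInExecutionContext is the inverse of saveRegistersInExecutionContext:
// it loads the given registers back from the saved-registers area of the execution context.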
func (m *machine) restoreRegistersInExecutionContext(cur *instruction, regs []regalloc.VReg) *instruction {
	offset := wazevoapi.ExecutionContextOffsetSavedRegistersBegin.I64()
	for _, v := range regs {
		load := m.allocateInstr()
		var as func(dst operand, amode addressMode, sizeInBits byte)
		var sizeInBits byte
		switch v.RegType() {
		case regalloc.RegTypeInt:
			as = load.asULoad
			sizeInBits = 64
		case regalloc.RegTypeFloat:
			as = load.asFpuLoad
			sizeInBits = 128
		}
		as(operandNR(v),
			addressMode{
				kind: addressModeKindRegUnsignedImm12,
				// Execution context is always the first argument.
				rn: x0VReg, imm: offset,
			}, sizeInBits)
		cur = linkInstr(cur, load)
		offset += 16 // The imm12 offset must be 16-byte aligned for vector registers, so we unconditionally load each register from a multiple of 16.
	}
	return cur
}

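// lowerConstantI64AndInsert materializes the 64-bit constant v into dst and links the
// resulting instructions after cur, reusing the pending-instructions buffer of the
// executable context.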
func (m *machine) lowerConstantI64AndInsert(cur *instruction, dst regalloc.VReg, v int64) *instruction {
	exct := m.executableContext
	exct.PendingInstructions = exct.PendingInstructions[:0]
	m.lowerConstantI64(dst, v)
	for _, instr := range exct.PendingInstructions {
		cur = linkInstr(cur, instr)
	}
	return cur
}

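// lowerConstantI32AndInsert is the 32-bit counterpart of lowerConstantI64AndInsert.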
func (m *machine) lowerConstantI32AndInsert(cur *instruction, dst regalloc.VReg, v int32) *instruction {
	exct := m.executableContext
	exct.PendingInstructions = exct.PendingInstructions[:0]
	m.lowerConstantI32(dst, v)
	for _, instr := range exct.PendingInstructions {
		cur = linkInstr(cur, instr)
	}
	return cur
}

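// setExitCode stores exitCode as a 32-bit value into the exit-code field of the
// execution context pointed to by execCtr.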
func (m *machine) setExitCode(cur *instruction, execCtr regalloc.VReg, exitCode wazevoapi.ExitCode) *instruction {
	constReg := x17VReg // Caller-saved, so we can use it.
	cur = m.lowerConstantI32AndInsert(cur, constReg, int32(exitCode))

	// Set the exit status on the execution context.
	setExitStatus := m.allocateInstr()
	setExitStatus.asStore(operandNR(constReg),
		addressMode{
			kind: addressModeKindRegUnsignedImm12,
			rn:   execCtr, imm: wazevoapi.ExecutionContextOffsetExitCodeOffset.I64(),
		}, 32)
	cur = linkInstr(cur, setExitStatus)
	return cur
}

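// storeReturnAddressAndExit saves the address of the instruction right after the exit
// sequence into the execution context, so that execution can resume there after the Go
// call, and then emits the exit sequence itself.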
func (m *machine) storeReturnAddressAndExit(cur *instruction) *instruction {
	// Read the return address into tmp, and store it in the execution context.
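	// The adr offset is exitSequenceSize+8 because the adr itself and the following
	// str are 4 bytes each, so the computed address points just past the exit sequence.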
	adr := m.allocateInstr()
	adr.asAdr(tmpRegVReg, exitSequenceSize+8)
	cur = linkInstr(cur, adr)

	storeReturnAddr := m.allocateInstr()
	storeReturnAddr.asStore(operandNR(tmpRegVReg),
		addressMode{
			kind: addressModeKindRegUnsignedImm12,
			// Execution context is always the first argument.
			rn: x0VReg, imm: wazevoapi.ExecutionContextOffsetGoCallReturnAddress.I64(),
		}, 64)
	cur = linkInstr(cur, storeReturnAddr)

	// Exit the execution.
	trapSeq := m.allocateInstr()
	trapSeq.asExitSequence(x0VReg)
	cur = linkInstr(cur, trapSeq)
	return cur
}

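// saveCurrentStackPointer records the current SP in the stackPointerBeforeGoCall field
// of the execution context pointed to by execCtr.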
func (m *machine) saveCurrentStackPointer(cur *instruction, execCtr regalloc.VReg) *instruction {
	// Save the current stack pointer:
	// 	mov tmp, sp
	// 	str tmp, [exec_ctx, #stackPointerBeforeGoCall]
	movSp := m.allocateInstr()
	movSp.asMove64(tmpRegVReg, spVReg)
	cur = linkInstr(cur, movSp)

	strSp := m.allocateInstr()
	strSp.asStore(operandNR(tmpRegVReg),
		addressMode{
			kind: addressModeKindRegUnsignedImm12,
			rn:   execCtr, imm: wazevoapi.ExecutionContextOffsetStackPointerBeforeGoCall.I64(),
		}, 64)
	cur = linkInstr(cur, strSp)
	return cur
}

// goFunctionCallRequiredStackSize returns the size of the stack region required for the
// Go function call: ret is rounded up to 16-byte alignment, while retUnaligned is the
// raw byte size of the argument/result region.
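// Arguments and results share the same slots, so the requirement is the maximum of the
// two. For example, for a hypothetical signature (exec_ctx, i32, f64) -> v128 with
// argBegin=1, the params need 8+8 = 16 bytes and the result needs 16 bytes, so both ret
// and retUnaligned are 16.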
func goFunctionCallRequiredStackSize(sig *ssa.Signature, argBegin int) (ret, retUnaligned int64) {
	var paramNeededInBytes, resultNeededInBytes int64
	for _, p := range sig.Params[argBegin:] {
		s := int64(p.Size())
		if s < 8 {
			s = 8 // We use uint64 for all basic types, except SIMD v128.
		}
		paramNeededInBytes += s
	}
	for _, r := range sig.Results {
		s := int64(r.Size())
		if s < 8 {
			s = 8 // We use uint64 for all basic types, except SIMD v128.
		}
		resultNeededInBytes += s
	}

	if paramNeededInBytes > resultNeededInBytes {
		ret = paramNeededInBytes
	} else {
		ret = resultNeededInBytes
	}
	retUnaligned = ret
	// Align to 16 bytes.
	ret = (ret + 15) &^ 15
	return
}

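// goFunctionCallLoadStackArg loads one stack-passed argument from the original Wasm
// stack (post-index advancing originalArg0Reg) into intVReg or floatVReg depending on
// the type, and returns the register that now holds the value.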
func (m *machine) goFunctionCallLoadStackArg(cur *instruction, originalArg0Reg regalloc.VReg, arg *backend.ABIArg, intVReg, floatVReg regalloc.VReg) (*instruction, regalloc.VReg) {
	load := m.allocateInstr()
	var result regalloc.VReg
	mode := addressMode{kind: addressModeKindPostIndex, rn: originalArg0Reg}
	switch arg.Type {
	case ssa.TypeI32:
		mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
		load.asULoad(operandNR(intVReg), mode, 32)
		result = intVReg
	case ssa.TypeI64:
		mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
		load.asULoad(operandNR(intVReg), mode, 64)
		result = intVReg
	case ssa.TypeF32:
		mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
		load.asFpuLoad(operandNR(floatVReg), mode, 32)
		result = floatVReg
	case ssa.TypeF64:
		mode.imm = 8 // We use uint64 for all basic types, except SIMD v128.
		load.asFpuLoad(operandNR(floatVReg), mode, 64)
		result = floatVReg
	case ssa.TypeV128:
		mode.imm = 16
		load.asFpuLoad(operandNR(floatVReg), mode, 128)
		result = floatVReg
	default:
		panic("TODO")
	}

	cur = linkInstr(cur, load)
	return cur, result
}

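// goFunctionCallStoreStackResult stores one result produced by the Go function back
// into its original Wasm stack slot, post-index advancing originalRet0Reg.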
func (m *machine) goFunctionCallStoreStackResult(cur *instruction, originalRet0Reg regalloc.VReg, result *backend.ABIArg, resultVReg regalloc.VReg) *instruction {
	store := m.allocateInstr()
	mode := addressMode{kind: addressModeKindPostIndex, rn: originalRet0Reg}
	var sizeInBits byte
	switch result.Type {
	case ssa.TypeI32, ssa.TypeF32:
		mode.imm = 8
		sizeInBits = 32
	case ssa.TypeI64, ssa.TypeF64:
		mode.imm = 8
		sizeInBits = 64
	case ssa.TypeV128:
		mode.imm = 16
		sizeInBits = 128
	default:
		panic("TODO")
	}
	store.asStore(operandNR(resultVReg), mode, sizeInBits)
	return linkInstr(cur, store)
}