github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_value_location.go

package compiler

import (
	"fmt"
	"strings"

	"github.com/bananabytelabs/wazero/internal/asm"
	"github.com/bananabytelabs/wazero/internal/wasm"
)

var (
	// unreservedGeneralPurposeRegisters contains unreserved general purpose registers of integer type.
	unreservedGeneralPurposeRegisters []asm.Register

	// unreservedVectorRegisters contains unreserved vector registers.
	unreservedVectorRegisters []asm.Register
)

// isGeneralPurposeRegister returns true if the given register is one of the unreserved
// general purpose registers. It assumes unreservedGeneralPurposeRegisters forms a
// contiguous, ascending range of register values.
func isGeneralPurposeRegister(r asm.Register) bool {
	return unreservedGeneralPurposeRegisters[0] <= r && r <= unreservedGeneralPurposeRegisters[len(unreservedGeneralPurposeRegisters)-1]
}

// isVectorRegister returns true if the given register is one of the unreserved vector
// registers. It assumes unreservedVectorRegisters forms a contiguous, ascending range.
func isVectorRegister(r asm.Register) bool {
	return unreservedVectorRegisters[0] <= r && r <= unreservedVectorRegisters[len(unreservedVectorRegisters)-1]
}

// runtimeValueLocation corresponds to each variable pushed onto the wazeroir (virtual) stack,
// and it has the information about where it exists in the physical machine.
// It might exist in a register, or on the non-virtual physical stack allocated in memory.
type runtimeValueLocation struct {
	valueType runtimeValueType
	// register is set to asm.NilRegister if the value is stored in the memory stack.
	register asm.Register
	// conditionalRegister is set to asm.ConditionalRegisterStateUnset if the value is not on a conditional register.
	conditionalRegister asm.ConditionalRegisterState
	// stackPointer is the location of this value in the memory stack at runtime.
	stackPointer uint64
}

// getRegisterType returns the registerType that can hold this value.
func (v *runtimeValueLocation) getRegisterType() (ret registerType) {
	switch v.valueType {
	case runtimeValueTypeI32, runtimeValueTypeI64:
		ret = registerTypeGeneralPurpose
	case runtimeValueTypeF32, runtimeValueTypeF64,
		runtimeValueTypeV128Lo, runtimeValueTypeV128Hi:
		ret = registerTypeVector
	default:
		panic("BUG")
	}
	return
}

// runtimeValueType represents the type of a value held in a runtimeValueLocation.
type runtimeValueType byte

const (
	runtimeValueTypeNone runtimeValueType = iota
	runtimeValueTypeI32
	runtimeValueTypeI64
	runtimeValueTypeF32
	runtimeValueTypeF64
	runtimeValueTypeV128Lo
	runtimeValueTypeV128Hi
)

func (r runtimeValueType) String() (ret string) {
	switch r {
	case runtimeValueTypeI32:
		ret = "i32"
	case runtimeValueTypeI64:
		ret = "i64"
	case runtimeValueTypeF32:
		ret = "f32"
	case runtimeValueTypeF64:
		ret = "f64"
	case runtimeValueTypeV128Lo:
		ret = "v128.lo"
	case runtimeValueTypeV128Hi:
		ret = "v128.hi"
	}
	return
}

// setRegister records that this value now lives in the given register.
func (v *runtimeValueLocation) setRegister(reg asm.Register) {
	v.register = reg
	v.conditionalRegister = asm.ConditionalRegisterStateUnset
}

// onRegister returns true if this value is held in a general purpose or vector register.
func (v *runtimeValueLocation) onRegister() bool {
	return v.register != asm.NilRegister && v.conditionalRegister == asm.ConditionalRegisterStateUnset
}

// onStack returns true if this value lives in the memory stack.
func (v *runtimeValueLocation) onStack() bool {
	return v.register == asm.NilRegister && v.conditionalRegister == asm.ConditionalRegisterStateUnset
}

// onConditionalRegister returns true if this value is represented by a conditional register (flag) state.
func (v *runtimeValueLocation) onConditionalRegister() bool {
	return v.conditionalRegister != asm.ConditionalRegisterStateUnset
}

func (v *runtimeValueLocation) String() string {
	var location string
	if v.onStack() {
		location = fmt.Sprintf("stack(%d)", v.stackPointer)
	} else if v.onConditionalRegister() {
		location = fmt.Sprintf("conditional(%d)", v.conditionalRegister)
	} else if v.onRegister() {
		location = fmt.Sprintf("register(%s)", registerNameFn(v.register))
	}
	return fmt.Sprintf("{type=%s,location=%s}", v.valueType, location)
}

// newRuntimeValueLocationStack returns a runtimeValueLocationStack initialized with the
// architecture-dependent unreserved register lists.
func newRuntimeValueLocationStack() runtimeValueLocationStack {
	return runtimeValueLocationStack{
		unreservedVectorRegisters:         unreservedVectorRegisters,
		unreservedGeneralPurposeRegisters: unreservedGeneralPurposeRegisters,
	}
}

// runtimeValueLocationStack represents the wazeroir virtual stack
// where each item holds the location information about where it exists
// on the physical machine at runtime.
//
// Notably this is only used in the compilation phase, not at runtime,
// and we change the state of this struct at every wazeroir operation we compile.
// In this way, we can see where the operands of an operation (for example,
// the two variables of a wazeroir add operation) currently live, and decide whether
// they must be moved into registers before emitting the actual CPU instructions
// that implement the add. An illustrative sketch follows the struct definition.
type runtimeValueLocationStack struct {
	// stack holds all the variables.
	stack []runtimeValueLocation
	// sp is the current stack pointer.
	sp uint64
	// usedRegisters is the bitmap tracking the used registers.
	usedRegisters usedRegistersMask
	// stackPointerCeil tracks max(.sp) across the lifespan of this struct.
	stackPointerCeil uint64
	// unreservedGeneralPurposeRegisters and unreservedVectorRegisters hold the
	// architecture-dependent unreserved register lists.
	unreservedGeneralPurposeRegisters, unreservedVectorRegisters []asm.Register
}
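
// exampleCompileAddLocations is an illustrative sketch, not used by the compiler itself:
// it shows how a wazeroir add operation might consult the location stack during compilation.
// The register parameters are hypothetical placeholders chosen by the caller.
func exampleCompileAddLocations(s *runtimeValueLocationStack, regX, regY asm.Register) *runtimeValueLocation {
	// The operands would have been pushed by previously compiled operations.
	s.pushRuntimeValueLocationOnRegister(regX, runtimeValueTypeI32)
	s.pushRuntimeValueLocationOnRegister(regY, runtimeValueTypeI32)
	s.markRegisterUsed(regX, regY)

	// Compiling "add": pop the two operand locations. The CPU-level add of x.register and
	// y.register would be emitted here (omitted in this sketch).
	y, x := s.pop(), s.pop()
	// y's register is no longer needed; the result reuses x's register.
	s.releaseRegister(y)
	return s.pushRuntimeValueLocationOnRegister(x.register, runtimeValueTypeI32)
}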

// reset resets this location stack for reuse, keeping the underlying .stack slice so that
// its capacity can be reused across compilations.
func (v *runtimeValueLocationStack) reset() {
	stack := v.stack[:0]
	*v = runtimeValueLocationStack{
		unreservedVectorRegisters:         unreservedVectorRegisters,
		unreservedGeneralPurposeRegisters: unreservedGeneralPurposeRegisters,
		stack:                             stack,
	}
}

func (v *runtimeValueLocationStack) String() string {
	var stackStr []string
	for i := uint64(0); i < v.sp; i++ {
		stackStr = append(stackStr, v.stack[i].String())
	}
	usedRegisters := v.usedRegisters.list()
	return fmt.Sprintf("sp=%d, stack=[%s], used_registers=[%s]", v.sp, strings.Join(stackStr, ","), strings.Join(usedRegisters, ","))
}

// cloneFrom clones the values of `from` into the receiver except for the .stack slice itself.
// The contents of .stack are copied from `from` into the receiver, growing the receiver's
// underlying slice if necessary.
func (v *runtimeValueLocationStack) cloneFrom(from runtimeValueLocationStack) {
	// Assign the same values for all fields except for the stack slice, which we want to reuse.
	prev := v.stack
	*v = from
	v.stack = prev[:cap(prev)] // Expand the length to the capacity so that we can minimize "diff" below.
	// Copy the content in the stack.
	if diff := int(from.sp) - len(v.stack); diff > 0 {
		v.stack = append(v.stack, make([]runtimeValueLocation, diff)...)
	}
	copy(v.stack, from.stack[:from.sp])
}
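
// exampleCloneLocationStack is an illustrative sketch (not part of the compiler) showing the
// intended usage of cloneFrom: the destination keeps and reuses its own backing .stack slice,
// so repeated clones avoid reallocating when the existing capacity is already sufficient.
func exampleCloneLocationStack(dst *runtimeValueLocationStack, src runtimeValueLocationStack) {
	dst.cloneFrom(src)
	// dst now has the same sp, used-register mask, and value locations as src, but the stack
	// contents were copied into dst's own slice (grown if needed) rather than aliasing src's.
}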

// pushRuntimeValueLocationOnRegister creates a new runtimeValueLocation with a given register and pushes it onto
// the location stack.
func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnRegister(reg asm.Register, vt runtimeValueType) (loc *runtimeValueLocation) {
	loc = v.push(reg, asm.ConditionalRegisterStateUnset)
	loc.valueType = vt
	return
}

// pushRuntimeValueLocationOnStack creates a new memory-stack-located runtimeValueLocation and pushes it onto the location stack.
func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnStack() (loc *runtimeValueLocation) {
	loc = v.push(asm.NilRegister, asm.ConditionalRegisterStateUnset)
	loc.valueType = runtimeValueTypeNone
	return
}

// pushRuntimeValueLocationOnConditionalRegister creates a new runtimeValueLocation with a given conditional register state
// and pushes it onto the location stack.
func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnConditionalRegister(state asm.ConditionalRegisterState) (loc *runtimeValueLocation) {
	loc = v.push(asm.NilRegister, state)
	loc.valueType = runtimeValueTypeI32
	return
}

// push a runtimeValueLocation onto the stack.
func (v *runtimeValueLocationStack) push(reg asm.Register, conditionalRegister asm.ConditionalRegisterState) (ret *runtimeValueLocation) {
	if v.sp >= uint64(len(v.stack)) {
		// In this case, we need to grow the stack capacity by appending a new item
		// rather than indexing into the existing slice.
		v.stack = append(v.stack, runtimeValueLocation{})
	}
	ret = &v.stack[v.sp]
	ret.register, ret.conditionalRegister, ret.stackPointer = reg, conditionalRegister, v.sp
	v.sp++
	// stackPointerCeil must be set after sp is incremented since
	// we skip the stack grow if len(stack) >= basePointer+stackPointerCeil.
	if v.sp > v.stackPointerCeil {
		v.stackPointerCeil = v.sp
	}
	return
}

// pop removes and returns the location at the top of the stack.
func (v *runtimeValueLocationStack) pop() (loc *runtimeValueLocation) {
	v.sp--
	loc = &v.stack[v.sp]
	return
}

// popV128 removes a v128 value, which occupies two slots (lo and hi), and returns the
// location of its lo half.
func (v *runtimeValueLocationStack) popV128() (loc *runtimeValueLocation) {
	v.sp -= 2
	loc = &v.stack[v.sp]
	return
}

// peek returns the location at the top of the stack without popping it.
func (v *runtimeValueLocationStack) peek() (loc *runtimeValueLocation) {
	loc = &v.stack[v.sp-1]
	return
}
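
// examplePushPopV128 is an illustrative sketch (not used by the compiler): a v128 value
// occupies two consecutive slots (lo then hi), so pushing one means pushing both halves,
// and popV128 rewinds both slots at once.
func examplePushPopV128(s *runtimeValueLocationStack, reg asm.Register) *runtimeValueLocation {
	// In the actual compilers, both halves typically record the same vector register.
	s.pushRuntimeValueLocationOnRegister(reg, runtimeValueTypeV128Lo)
	s.pushRuntimeValueLocationOnRegister(reg, runtimeValueTypeV128Hi)

	popped := s.popV128()
	// popped is the location of the lo half, and s.sp is back to its value before the pushes.
	return popped
}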

// releaseRegister frees the register held by the given location, moving it back to the
// "on stack" state.
func (v *runtimeValueLocationStack) releaseRegister(loc *runtimeValueLocation) {
	v.markRegisterUnused(loc.register)
	loc.register = asm.NilRegister
	loc.conditionalRegister = asm.ConditionalRegisterStateUnset
}

// markRegisterUnused drops the given registers from the used-register set.
func (v *runtimeValueLocationStack) markRegisterUnused(regs ...asm.Register) {
	for _, reg := range regs {
		v.usedRegisters.remove(reg)
	}
}

// markRegisterUsed adds the given registers to the used-register set.
func (v *runtimeValueLocationStack) markRegisterUsed(regs ...asm.Register) {
	for _, reg := range regs {
		v.usedRegisters.add(reg)
	}
}

// registerType represents the kind of physical register: general purpose (integer) or vector.
type registerType byte

const (
	// registerTypeGeneralPurpose represents a general purpose (integer) register.
	registerTypeGeneralPurpose registerType = iota
	// registerTypeVector represents a vector register which can be used for either scalar float
	// operations or SIMD vector operations depending on the instruction by which the register is used.
	//
	// Note: In assembly notation, scalar float and vector registers are written differently:
	// on arm64, for example, Vn refers to a vector register while Sn/Dn refer to scalar floats.
	// But on physical hardware they occupy the same register file (Dn is the lower 64 bits of
	// the Vn vector register on arm64).
	//
	// In wazero, for the sake of simplicity in the register allocation, we intentionally conflate these two types
	// and delegate the decision to the assembler, which is aware of the instruction types for which these registers are used.
	registerTypeVector
)

func (tp registerType) String() (ret string) {
	switch tp {
	case registerTypeGeneralPurpose:
		ret = "int"
	case registerTypeVector:
		ret = "vector"
	}
	return
}

// takeFreeRegister searches for an unused register of the given type. If one is found, it is
// returned with found=true; the caller is responsible for marking it as used.
func (v *runtimeValueLocationStack) takeFreeRegister(tp registerType) (reg asm.Register, found bool) {
	var targetRegs []asm.Register
	switch tp {
	case registerTypeVector:
		targetRegs = v.unreservedVectorRegisters
	case registerTypeGeneralPurpose:
		targetRegs = v.unreservedGeneralPurposeRegisters
	}
	for _, candidate := range targetRegs {
		if v.usedRegisters.exist(candidate) {
			continue
		}
		return candidate, true
	}
	return 0, false
}

// takeStealTargetFromUsedRegister searches the stack from the bottom and returns the first
// value that currently occupies a register of the given type, so that its register can be
// stolen (spilled) by the caller.
func (v *runtimeValueLocationStack) takeStealTargetFromUsedRegister(tp registerType) (*runtimeValueLocation, bool) {
	for i := uint64(0); i < v.sp; i++ {
		loc := &v.stack[i]
		if loc.onRegister() {
			switch tp {
			case registerTypeVector:
				if loc.valueType == runtimeValueTypeV128Hi {
					panic("BUG: V128Hi must be above the corresponding V128Lo")
				}
				if isVectorRegister(loc.register) {
					return loc, true
				}
			case registerTypeGeneralPurpose:
				if isGeneralPurposeRegister(loc.register) {
					return loc, true
				}
			}
		}
	}
	return nil, false
}

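// exampleAllocateRegister is an illustrative sketch (not part of the compiler) of the typical
// allocation flow built on takeFreeRegister and takeStealTargetFromUsedRegister: prefer a free
// register and, failing that, steal one from a value already on the stack. Spilling the stolen
// value back to the memory stack is architecture specific and omitted here.
func exampleAllocateRegister(s *runtimeValueLocationStack, tp registerType) (asm.Register, bool) {
	if reg, found := s.takeFreeRegister(tp); found {
		s.markRegisterUsed(reg)
		return reg, true
	}
	if stealTarget, found := s.takeStealTargetFromUsedRegister(tp); found {
		reg := stealTarget.register
		// ...emit code to store stealTarget back to its memory stack slot here (omitted)...
		s.releaseRegister(stealTarget)
		s.markRegisterUsed(reg)
		return reg, true
	}
	return asm.NilRegister, false
}
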
// init sets up the runtimeValueLocationStack so that it reflects the state of
// the stack at the beginning of the function.
//
// See the diagram in callEngine.stack.
func (v *runtimeValueLocationStack) init(sig *wasm.FunctionType) {
	for _, t := range sig.Params {
		loc := v.pushRuntimeValueLocationOnStack()
		switch t {
		case wasm.ValueTypeI32:
			loc.valueType = runtimeValueTypeI32
		case wasm.ValueTypeI64, wasm.ValueTypeFuncref, wasm.ValueTypeExternref:
			loc.valueType = runtimeValueTypeI64
		case wasm.ValueTypeF32:
			loc.valueType = runtimeValueTypeF32
		case wasm.ValueTypeF64:
			loc.valueType = runtimeValueTypeF64
		case wasm.ValueTypeV128:
			loc.valueType = runtimeValueTypeV128Lo
			hi := v.pushRuntimeValueLocationOnStack()
			hi.valueType = runtimeValueTypeV128Hi
		default:
			panic("BUG")
		}
	}

	// If len(results) > len(params), the slots for all results are reserved after
	// the arguments, so we reflect that in the location stack.
	for i := 0; i < sig.ResultNumInUint64-sig.ParamNumInUint64; i++ {
		_ = v.pushRuntimeValueLocationOnStack()
	}

	// Then push the call frame fields.
	for i := 0; i < callFrameDataSizeInUint64; i++ {
		loc := v.pushRuntimeValueLocationOnStack()
		loc.valueType = runtimeValueTypeI64
	}
}
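
// exampleInitLayout is an illustrative sketch (not part of the compiler) of the layout that
// init produces for a hypothetical signature (i32, v128) -> (i64, i64, i64, i64): the four
// i64 results need four slots while the parameters occupy only three, so one extra slot is
// reserved after the arguments, followed by the call frame fields. In practice the
// *NumInUint64 counts are precomputed on wasm.FunctionType rather than filled in by hand.
func exampleInitLayout() runtimeValueLocationStack {
	sig := &wasm.FunctionType{
		Params:            []wasm.ValueType{wasm.ValueTypeI32, wasm.ValueTypeV128},
		Results:           []wasm.ValueType{wasm.ValueTypeI64, wasm.ValueTypeI64, wasm.ValueTypeI64, wasm.ValueTypeI64},
		ParamNumInUint64:  3,
		ResultNumInUint64: 4,
	}
	s := newRuntimeValueLocationStack()
	s.init(sig)
	// s.sp is now 3 (params) + 1 (reserved result slot) + callFrameDataSizeInUint64.
	return s
}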

// getCallFrameLocations returns each field of callFrame's runtime location.
//
// See the diagram in callEngine.stack.
func (v *runtimeValueLocationStack) getCallFrameLocations(sig *wasm.FunctionType) (
	returnAddress, callerStackBasePointerInBytes, callerFunction *runtimeValueLocation,
) {
	offset := callFrameOffset(sig)
	return &v.stack[offset], &v.stack[offset+1], &v.stack[offset+2]
}

// pushCallFrame pushes a call frame's runtime locations onto the stack assuming that
// the function call parameters are already pushed there.
//
// See the diagram in callEngine.stack.
func (v *runtimeValueLocationStack) pushCallFrame(callTargetFunctionType *wasm.FunctionType) (
	returnAddress, callerStackBasePointerInBytes, callerFunction *runtimeValueLocation,
) {
	// If len(results) > len(args), we reserve the slots for the results below the call frame.
	reservedSlotsBeforeCallFrame := callTargetFunctionType.ResultNumInUint64 - callTargetFunctionType.ParamNumInUint64
	for i := 0; i < reservedSlotsBeforeCallFrame; i++ {
		v.pushRuntimeValueLocationOnStack()
	}

	// Push the runtime location for each field of callFrame struct. Note that each of them has
	// uint64 type, and therefore must be treated as runtimeValueTypeI64.

	// callFrame.returnAddress
	returnAddress = v.pushRuntimeValueLocationOnStack()
	returnAddress.valueType = runtimeValueTypeI64
	// callFrame.returnStackBasePointerInBytes
	callerStackBasePointerInBytes = v.pushRuntimeValueLocationOnStack()
	callerStackBasePointerInBytes.valueType = runtimeValueTypeI64
	// callFrame.function
	callerFunction = v.pushRuntimeValueLocationOnStack()
	callerFunction.valueType = runtimeValueTypeI64
	return
}

// usedRegistersMask tracks the used registers in its bits.
type usedRegistersMask uint64

// add adds the given `r` to the mask.
func (u *usedRegistersMask) add(r asm.Register) {
	*u = *u | (1 << registerMaskShift(r))
}

// remove drops the given `r` from the mask.
func (u *usedRegistersMask) remove(r asm.Register) {
	*u = *u & ^(1 << registerMaskShift(r))
}

// exist returns true if the given `r` is used.
func (u *usedRegistersMask) exist(r asm.Register) bool {
	shift := registerMaskShift(r)
	return (*u & (1 << shift)) > 0
}
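
// exampleUsedRegistersMask is an illustrative sketch (not part of the compiler) of how the
// mask is typically consulted: a register is added when allocated, queried while searching
// for a free one, and removed when released.
func exampleUsedRegistersMask(r asm.Register) bool {
	var mask usedRegistersMask
	mask.add(r)
	inUse := mask.exist(r) // true: r is currently marked as used.
	mask.remove(r)
	return inUse && !mask.exist(r) // true: r is no longer marked as used.
}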

// list returns the debug strings of the used registers.
// Only used for debugging and testing.
func (u *usedRegistersMask) list() (ret []string) {
	mask := *u
	for i := 0; i < 64; i++ {
		if mask&(1<<i) > 0 {
			ret = append(ret, registerNameFn(registerFromMaskShift(i)))
		}
	}
	return
}