github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_value_location_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/bananabytelabs/wazero/internal/asm"
     7  	"github.com/bananabytelabs/wazero/internal/testing/require"
     8  	"github.com/bananabytelabs/wazero/internal/wasm"
     9  )
    10  
    11  func Test_isIntRegister(t *testing.T) {
    12  	for _, r := range unreservedGeneralPurposeRegisters {
    13  		require.True(t, isGeneralPurposeRegister(r))
    14  	}
    15  }
    16  
    17  func Test_isVectorRegister(t *testing.T) {
    18  	for _, r := range unreservedVectorRegisters {
    19  		require.True(t, isVectorRegister(r))
    20  	}
    21  }
    22  
    23  func TestRuntimeValueLocationStack_basic(t *testing.T) {
    24  	s := newRuntimeValueLocationStack()
    25  	// Push stack value.
    26  	loc := s.pushRuntimeValueLocationOnStack()
    27  	require.Equal(t, uint64(1), s.sp)
    28  	require.Equal(t, uint64(0), loc.stackPointer)
    29  	// Push the register value.
    30  	tmpReg := unreservedGeneralPurposeRegisters[0]
    31  	loc = s.pushRuntimeValueLocationOnRegister(tmpReg, runtimeValueTypeI64)
    32  	require.Equal(t, uint64(2), s.sp)
    33  	require.Equal(t, uint64(1), loc.stackPointer)
    34  	require.Equal(t, tmpReg, loc.register)
    35  	require.Equal(t, loc.valueType, runtimeValueTypeI64)
    36  	// markRegisterUsed.
    37  	tmpReg2 := unreservedGeneralPurposeRegisters[1]
    38  	s.markRegisterUsed(tmpReg2)
    39  	require.True(t, s.usedRegisters.exist(tmpReg2))
    40  	// releaseRegister.
    41  	s.releaseRegister(loc)
    42  	require.False(t, s.usedRegisters.exist(loc.register))
    43  	require.Equal(t, asm.NilRegister, loc.register)
    44  	// Check the max stack pointer.
    45  	for i := 0; i < 1000; i++ {
    46  		s.pushRuntimeValueLocationOnStack()
    47  	}
    48  	for i := 0; i < 1000; i++ {
    49  		s.pop()
    50  	}
    51  	require.Equal(t, uint64(1002), s.stackPointerCeil)
    52  }
    53  
    54  func TestRuntimeValueLocationStack_takeFreeRegister(t *testing.T) {
    55  	s := newRuntimeValueLocationStack()
    56  	// For int registers.
    57  	r, ok := s.takeFreeRegister(registerTypeGeneralPurpose)
    58  	require.True(t, ok)
    59  	require.True(t, isGeneralPurposeRegister(r))
    60  	// Mark all the int registers used.
    61  	for _, r := range unreservedGeneralPurposeRegisters {
    62  		s.markRegisterUsed(r)
    63  	}
    64  	// Now we cannot take free ones for int.
    65  	_, ok = s.takeFreeRegister(registerTypeGeneralPurpose)
    66  	require.False(t, ok)
    67  	// But we still should be able to take float regs.
    68  	r, ok = s.takeFreeRegister(registerTypeVector)
    69  	require.True(t, ok)
    70  	require.True(t, isVectorRegister(r))
    71  	// Mark all the float registers used.
    72  	for _, r := range unreservedVectorRegisters {
    73  		s.markRegisterUsed(r)
    74  	}
    75  	// Now we cannot take free ones for floats.
    76  	_, ok = s.takeFreeRegister(registerTypeVector)
    77  	require.False(t, ok)
    78  }
    79  
    80  func TestRuntimeValueLocationStack_takeStealTargetFromUsedRegister(t *testing.T) {
    81  	s := newRuntimeValueLocationStack()
    82  	intReg := unreservedGeneralPurposeRegisters[0]
    83  	floatReg := unreservedVectorRegisters[0]
    84  	intLocation := s.push(intReg, asm.ConditionalRegisterStateUnset)
    85  	floatLocation := s.push(floatReg, asm.ConditionalRegisterStateUnset)
    86  	// Take for float.
    87  	target, ok := s.takeStealTargetFromUsedRegister(registerTypeVector)
    88  	require.True(t, ok)
    89  	require.Equal(t, floatLocation, target)
    90  	// Take for ints.
    91  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
    92  	require.True(t, ok)
    93  	require.Equal(t, intLocation, target)
    94  	// Pop float value.
    95  	popped := s.pop()
    96  	require.Equal(t, floatLocation, popped)
    97  	// Now we cannot find the steal target.
    98  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeVector)
    99  	require.False(t, ok)
   100  	require.Nil(t, target)
   101  	// Pop int value.
   102  	popped = s.pop()
   103  	require.Equal(t, intLocation, popped)
   104  	// Now we cannot find the steal target.
   105  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
   106  	require.False(t, ok)
   107  	require.Nil(t, target)
   108  }
   109  
   110  func TestRuntimeValueLocationStack_setupInitialStack(t *testing.T) {
   111  	const f32 = wasm.ValueTypeF32
   112  	tests := []struct {
   113  		name       string
   114  		sig        *wasm.FunctionType
   115  		expectedSP uint64
   116  	}{
   117  		{
   118  			name:       "no params / no results",
   119  			sig:        &wasm.FunctionType{},
   120  			expectedSP: callFrameDataSizeInUint64,
   121  		},
   122  		{
   123  			name: "no results",
   124  			sig: &wasm.FunctionType{
   125  				Params:           []wasm.ValueType{f32, f32},
   126  				ParamNumInUint64: 2,
   127  			},
   128  			expectedSP: callFrameDataSizeInUint64 + 2,
   129  		},
   130  		{
   131  			name: "no params",
   132  			sig: &wasm.FunctionType{
   133  				Results:           []wasm.ValueType{f32, f32},
   134  				ResultNumInUint64: 2,
   135  			},
   136  			expectedSP: callFrameDataSizeInUint64 + 2,
   137  		},
   138  		{
   139  			name: "params == results",
   140  			sig: &wasm.FunctionType{
   141  				Params:            []wasm.ValueType{f32, f32},
   142  				ParamNumInUint64:  2,
   143  				Results:           []wasm.ValueType{f32, f32},
   144  				ResultNumInUint64: 2,
   145  			},
   146  			expectedSP: callFrameDataSizeInUint64 + 2,
   147  		},
   148  		{
   149  			name: "params > results",
   150  			sig: &wasm.FunctionType{
   151  				Params:            []wasm.ValueType{f32, f32, f32},
   152  				ParamNumInUint64:  3,
   153  				Results:           []wasm.ValueType{f32, f32},
   154  				ResultNumInUint64: 2,
   155  			},
   156  			expectedSP: callFrameDataSizeInUint64 + 3,
   157  		},
   158  		{
   159  			name: "params <  results",
   160  			sig: &wasm.FunctionType{
   161  				Params:            []wasm.ValueType{f32},
   162  				ParamNumInUint64:  1,
   163  				Results:           []wasm.ValueType{f32, f32, f32},
   164  				ResultNumInUint64: 3,
   165  			},
   166  			expectedSP: callFrameDataSizeInUint64 + 3,
   167  		},
   168  	}
   169  
   170  	for _, tc := range tests {
   171  		tc := tc
   172  		t.Run(tc.name, func(t *testing.T) {
   173  			s := newRuntimeValueLocationStack()
   174  			s.init(tc.sig)
   175  			require.Equal(t, tc.expectedSP, s.sp)
   176  
   177  			callFrameLocations := s.stack[s.sp-callFrameDataSizeInUint64 : s.sp]
   178  			for _, loc := range callFrameLocations {
   179  				require.Equal(t, runtimeValueTypeI64, loc.valueType)
   180  			}
   181  		})
   182  	}
   183  }
   184  
   185  func TestRuntimeValueLocation_pushCallFrame(t *testing.T) {
   186  	for _, sig := range []*wasm.FunctionType{
   187  		{ParamNumInUint64: 0, ResultNumInUint64: 1},
   188  		{ParamNumInUint64: 1, ResultNumInUint64: 0},
   189  		{ParamNumInUint64: 1, ResultNumInUint64: 1},
   190  		{ParamNumInUint64: 0, ResultNumInUint64: 2},
   191  		{ParamNumInUint64: 2, ResultNumInUint64: 0},
   192  		{ParamNumInUint64: 2, ResultNumInUint64: 3},
   193  	} {
   194  		sig := sig
   195  		t.Run(sig.String(), func(t *testing.T) {
   196  			s := newRuntimeValueLocationStack()
   197  			// pushCallFrame assumes that the parameters are already pushed.
   198  			for i := 0; i < sig.ParamNumInUint64; i++ {
   199  				_ = s.pushRuntimeValueLocationOnStack()
   200  			}
   201  
   202  			retAddr, stackBasePointer, fn := s.pushCallFrame(sig)
   203  
   204  			expOffset := uint64(callFrameOffset(sig))
   205  			require.Equal(t, expOffset, retAddr.stackPointer)
   206  			require.Equal(t, expOffset+1, stackBasePointer.stackPointer)
   207  			require.Equal(t, expOffset+2, fn.stackPointer)
   208  		})
   209  	}
   210  }
   211  
   212  func Test_usedRegistersMask(t *testing.T) {
   213  	for _, r := range append(unreservedVectorRegisters, unreservedGeneralPurposeRegisters...) {
   214  		mask := usedRegistersMask(0)
   215  		mask.add(r)
   216  		require.False(t, mask == 0)
   217  		require.True(t, mask.exist(r))
   218  		mask.remove(r)
   219  		require.True(t, mask == 0)
   220  		require.False(t, mask.exist(r))
   221  	}
   222  }
   223  
   224  func TestRuntimeValueLocation_cloneFrom(t *testing.T) {
   225  	t.Run("sp<cap", func(t *testing.T) {
   226  		v := runtimeValueLocationStack{sp: 7, stack: make([]runtimeValueLocation, 5, 10)}
   227  		orig := v.stack
   228  		v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{
   229  			{register: 3}, {register: 2}, {register: 1},
   230  		}})
   231  		require.Equal(t, uint64(3), v.sp)
   232  		require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
   233  		// Underlying stack shouldn't have changed since sp=3 < cap(v.stack).
   234  		require.Equal(t, &orig[0], &v.stack[0])
   235  		require.Equal(t, v.stack[0].register, asm.Register(3))
   236  		require.Equal(t, v.stack[1].register, asm.Register(2))
   237  		require.Equal(t, v.stack[2].register, asm.Register(1))
   238  	})
   239  	t.Run("sp=cap", func(t *testing.T) {
   240  		v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)}
   241  		orig := v.stack[:cap(v.stack)]
   242  		v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{
   243  			{register: 3}, {register: 2}, {register: 1},
   244  		}})
   245  		require.Equal(t, uint64(3), v.sp)
   246  		require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
   247  		// Underlying stack shouldn't have changed since sp=3==cap(v.stack).
   248  		require.Equal(t, &orig[0], &v.stack[0])
   249  		require.Equal(t, v.stack[0].register, asm.Register(3))
   250  		require.Equal(t, v.stack[1].register, asm.Register(2))
   251  		require.Equal(t, v.stack[2].register, asm.Register(1))
   252  	})
   253  	t.Run("sp>cap", func(t *testing.T) {
   254  		v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)}
   255  		orig := v.stack[:cap(v.stack)]
   256  		v.cloneFrom(runtimeValueLocationStack{sp: 5, usedRegisters: 0xffff, stack: []runtimeValueLocation{
   257  			{register: 5}, {register: 4}, {register: 3}, {register: 2}, {register: 1},
   258  		}})
   259  		require.Equal(t, uint64(5), v.sp)
   260  		require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
   261  		// Underlying stack should have changed since sp=5>cap(v.stack).
   262  		require.NotEqual(t, &orig[0], &v.stack[0])
   263  		require.Equal(t, v.stack[0].register, asm.Register(5))
   264  		require.Equal(t, v.stack[1].register, asm.Register(4))
   265  		require.Equal(t, v.stack[2].register, asm.Register(3))
   266  		require.Equal(t, v.stack[3].register, asm.Register(2))
   267  		require.Equal(t, v.stack[4].register, asm.Register(1))
   268  	})
   269  }