wa-lang.org/wazero@v1.0.2/internal/engine/compiler/compiler_value_location_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"testing"
     5  	"unsafe"
     6  
     7  	"wa-lang.org/wazero/internal/asm"
     8  	"wa-lang.org/wazero/internal/testing/require"
     9  	"wa-lang.org/wazero/internal/wasm"
    10  )
    11  
    12  func Test_isIntRegister(t *testing.T) {
    13  	for _, r := range unreservedGeneralPurposeRegisters {
    14  		require.True(t, isGeneralPurposeRegister(r))
    15  	}
    16  }
    17  
    18  func Test_isVectorRegister(t *testing.T) {
    19  	for _, r := range unreservedVectorRegisters {
    20  		require.True(t, isVectorRegister(r))
    21  	}
    22  }
    23  
    24  func TestRuntimeValueLocationStack_basic(t *testing.T) {
    25  	s := newRuntimeValueLocationStack()
    26  	// Push stack value.
    27  	loc := s.pushRuntimeValueLocationOnStack()
    28  	require.Equal(t, uint64(1), s.sp)
    29  	require.Equal(t, uint64(0), loc.stackPointer)
    30  	// Push the register value.
    31  	tmpReg := unreservedGeneralPurposeRegisters[0]
    32  	loc = s.pushRuntimeValueLocationOnRegister(tmpReg, runtimeValueTypeI64)
    33  	require.Equal(t, uint64(2), s.sp)
    34  	require.Equal(t, uint64(1), loc.stackPointer)
    35  	require.Equal(t, tmpReg, loc.register)
    36  	require.Equal(t, loc.valueType, runtimeValueTypeI64)
    37  	// markRegisterUsed.
    38  	tmpReg2 := unreservedGeneralPurposeRegisters[1]
    39  	s.markRegisterUsed(tmpReg2)
    40  	require.NotNil(t, s.usedRegisters[tmpReg2], tmpReg2)
    41  	// releaseRegister.
    42  	s.releaseRegister(loc)
    43  	require.Equal(t, s.usedRegisters[loc.register], struct{}{}, "expected %v to not contain %v", s.usedRegisters, loc.register)
    44  	require.Equal(t, asm.NilRegister, loc.register)
    45  	// Clone.
    46  	cloned := s.clone()
    47  	require.Equal(t, s.usedRegisters, cloned.usedRegisters)
    48  	require.Equal(t, s.unreservedGeneralPurposeRegisters, cloned.unreservedGeneralPurposeRegisters)
    49  	require.Equal(t, s.unreservedVectorRegisters, cloned.unreservedVectorRegisters)
    50  	require.Equal(t, len(s.stack), len(cloned.stack))
    51  	require.Equal(t, s.sp, cloned.sp)
    52  	for i := 0; i < int(s.sp); i++ {
    53  		actual, exp := s.stack[i], cloned.stack[i]
    54  		require.NotEqual(t, uintptr(unsafe.Pointer(exp)), uintptr(unsafe.Pointer(actual)))
    55  	}
    56  	// Check the max stack pointer.
    57  	for i := 0; i < 1000; i++ {
    58  		s.pushRuntimeValueLocationOnStack()
    59  	}
    60  	for i := 0; i < 1000; i++ {
    61  		s.pop()
    62  	}
    63  	require.Equal(t, uint64(1002), s.stackPointerCeil)
    64  }
    65  
    66  func TestRuntimeValueLocationStack_takeFreeRegister(t *testing.T) {
    67  	s := newRuntimeValueLocationStack()
    68  	// For int registers.
    69  	r, ok := s.takeFreeRegister(registerTypeGeneralPurpose)
    70  	require.True(t, ok)
    71  	require.True(t, isGeneralPurposeRegister(r))
    72  	// Mark all the int registers used.
    73  	for _, r := range unreservedGeneralPurposeRegisters {
    74  		s.markRegisterUsed(r)
    75  	}
    76  	// Now we cannot take free ones for int.
    77  	_, ok = s.takeFreeRegister(registerTypeGeneralPurpose)
    78  	require.False(t, ok)
    79  	// But we still should be able to take float regs.
    80  	r, ok = s.takeFreeRegister(registerTypeVector)
    81  	require.True(t, ok)
    82  	require.True(t, isVectorRegister(r))
    83  	// Mark all the float registers used.
    84  	for _, r := range unreservedVectorRegisters {
    85  		s.markRegisterUsed(r)
    86  	}
    87  	// Now we cannot take free ones for floats.
    88  	_, ok = s.takeFreeRegister(registerTypeVector)
    89  	require.False(t, ok)
    90  }
    91  
    92  func TestRuntimeValueLocationStack_takeStealTargetFromUsedRegister(t *testing.T) {
    93  	s := newRuntimeValueLocationStack()
    94  	intReg := unreservedGeneralPurposeRegisters[0]
    95  	intLocation := &runtimeValueLocation{register: intReg}
    96  	floatReg := unreservedVectorRegisters[0]
    97  	floatLocation := &runtimeValueLocation{register: floatReg}
    98  	s.push(intLocation)
    99  	s.push(floatLocation)
   100  	// Take for float.
   101  	target, ok := s.takeStealTargetFromUsedRegister(registerTypeVector)
   102  	require.True(t, ok)
   103  	require.Equal(t, floatLocation, target)
   104  	// Take for ints.
   105  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
   106  	require.True(t, ok)
   107  	require.Equal(t, intLocation, target)
   108  	// Pop float value.
   109  	popped := s.pop()
   110  	require.Equal(t, floatLocation, popped)
   111  	// Now we cannot find the steal target.
   112  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeVector)
   113  	require.False(t, ok)
   114  	require.Nil(t, target)
   115  	// Pop int value.
   116  	popped = s.pop()
   117  	require.Equal(t, intLocation, popped)
   118  	// Now we cannot find the steal target.
   119  	target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
   120  	require.False(t, ok)
   121  	require.Nil(t, target)
   122  }
   123  
   124  func TestRuntimeValueLocationStack_setupInitialStack(t *testing.T) {
   125  	const f32 = wasm.ValueTypeF32
   126  	tests := []struct {
   127  		name       string
   128  		sig        *wasm.FunctionType
   129  		expectedSP uint64
   130  	}{
   131  		{
   132  			name:       "no params / no results",
   133  			sig:        &wasm.FunctionType{},
   134  			expectedSP: callFrameDataSizeInUint64,
   135  		},
   136  		{
   137  			name: "no results",
   138  			sig: &wasm.FunctionType{
   139  				Params:           []wasm.ValueType{f32, f32},
   140  				ParamNumInUint64: 2,
   141  			},
   142  			expectedSP: callFrameDataSizeInUint64 + 2,
   143  		},
   144  		{
   145  			name: "no params",
   146  			sig: &wasm.FunctionType{
   147  				Results:           []wasm.ValueType{f32, f32},
   148  				ResultNumInUint64: 2,
   149  			},
   150  			expectedSP: callFrameDataSizeInUint64 + 2,
   151  		},
   152  		{
   153  			name: "params == results",
   154  			sig: &wasm.FunctionType{
   155  				Params:            []wasm.ValueType{f32, f32},
   156  				ParamNumInUint64:  2,
   157  				Results:           []wasm.ValueType{f32, f32},
   158  				ResultNumInUint64: 2,
   159  			},
   160  			expectedSP: callFrameDataSizeInUint64 + 2,
   161  		},
   162  		{
   163  			name: "params > results",
   164  			sig: &wasm.FunctionType{
   165  				Params:            []wasm.ValueType{f32, f32, f32},
   166  				ParamNumInUint64:  3,
   167  				Results:           []wasm.ValueType{f32, f32},
   168  				ResultNumInUint64: 2,
   169  			},
   170  			expectedSP: callFrameDataSizeInUint64 + 3,
   171  		},
   172  		{
   173  			name: "params <  results",
   174  			sig: &wasm.FunctionType{
   175  				Params:            []wasm.ValueType{f32},
   176  				ParamNumInUint64:  1,
   177  				Results:           []wasm.ValueType{f32, f32, f32},
   178  				ResultNumInUint64: 3,
   179  			},
   180  			expectedSP: callFrameDataSizeInUint64 + 3,
   181  		},
   182  	}
   183  
   184  	for _, tc := range tests {
   185  		tc := tc
   186  		t.Run(tc.name, func(t *testing.T) {
   187  			s := newRuntimeValueLocationStack()
   188  			s.init(tc.sig)
   189  			require.Equal(t, tc.expectedSP, s.sp)
   190  
   191  			callFrameLocations := s.stack[s.sp-callFrameDataSizeInUint64 : s.sp]
   192  			for _, loc := range callFrameLocations {
   193  				require.Equal(t, runtimeValueTypeI64, loc.valueType)
   194  			}
   195  		})
   196  	}
   197  }
   198  
   199  func TestRuntimeValueLocation_pushCallFrame(t *testing.T) {
   200  	for _, sig := range []*wasm.FunctionType{
   201  		{ParamNumInUint64: 0, ResultNumInUint64: 1},
   202  		{ParamNumInUint64: 1, ResultNumInUint64: 0},
   203  		{ParamNumInUint64: 1, ResultNumInUint64: 1},
   204  		{ParamNumInUint64: 0, ResultNumInUint64: 2},
   205  		{ParamNumInUint64: 2, ResultNumInUint64: 0},
   206  		{ParamNumInUint64: 2, ResultNumInUint64: 3},
   207  	} {
   208  		sig := sig
   209  		t.Run(sig.String(), func(t *testing.T) {
   210  			s := newRuntimeValueLocationStack()
   211  			// pushCallFrame assumes that the parameters are already pushed.
   212  			s.sp += uint64(sig.ParamNumInUint64)
   213  
   214  			retAddr, stackBasePointer, fn := s.pushCallFrame(sig)
   215  
   216  			expOffset := uint64(callFrameOffset(sig))
   217  			require.Equal(t, expOffset, retAddr.stackPointer)
   218  			require.Equal(t, expOffset+1, stackBasePointer.stackPointer)
   219  			require.Equal(t, expOffset+2, fn.stackPointer)
   220  		})
   221  	}
   222  }