github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/compiler/compiler_value_location_test.go (about) 1 package compiler 2 3 import ( 4 "testing" 5 6 "github.com/wasilibs/wazerox/internal/asm" 7 "github.com/wasilibs/wazerox/internal/testing/require" 8 "github.com/wasilibs/wazerox/internal/wasm" 9 ) 10 11 func Test_isIntRegister(t *testing.T) { 12 for _, r := range unreservedGeneralPurposeRegisters { 13 require.True(t, isGeneralPurposeRegister(r)) 14 } 15 } 16 17 func Test_isVectorRegister(t *testing.T) { 18 for _, r := range unreservedVectorRegisters { 19 require.True(t, isVectorRegister(r)) 20 } 21 } 22 23 func TestRuntimeValueLocationStack_basic(t *testing.T) { 24 s := newRuntimeValueLocationStack() 25 // Push stack value. 26 loc := s.pushRuntimeValueLocationOnStack() 27 require.Equal(t, uint64(1), s.sp) 28 require.Equal(t, uint64(0), loc.stackPointer) 29 // Push the register value. 30 tmpReg := unreservedGeneralPurposeRegisters[0] 31 loc = s.pushRuntimeValueLocationOnRegister(tmpReg, runtimeValueTypeI64) 32 require.Equal(t, uint64(2), s.sp) 33 require.Equal(t, uint64(1), loc.stackPointer) 34 require.Equal(t, tmpReg, loc.register) 35 require.Equal(t, loc.valueType, runtimeValueTypeI64) 36 // markRegisterUsed. 37 tmpReg2 := unreservedGeneralPurposeRegisters[1] 38 s.markRegisterUsed(tmpReg2) 39 require.True(t, s.usedRegisters.exist(tmpReg2)) 40 // releaseRegister. 41 s.releaseRegister(loc) 42 require.False(t, s.usedRegisters.exist(loc.register)) 43 require.Equal(t, asm.NilRegister, loc.register) 44 // Check the max stack pointer. 45 for i := 0; i < 1000; i++ { 46 s.pushRuntimeValueLocationOnStack() 47 } 48 for i := 0; i < 1000; i++ { 49 s.pop() 50 } 51 require.Equal(t, uint64(1002), s.stackPointerCeil) 52 } 53 54 func TestRuntimeValueLocationStack_takeFreeRegister(t *testing.T) { 55 s := newRuntimeValueLocationStack() 56 // For int registers. 57 r, ok := s.takeFreeRegister(registerTypeGeneralPurpose) 58 require.True(t, ok) 59 require.True(t, isGeneralPurposeRegister(r)) 60 // Mark all the int registers used. 61 for _, r := range unreservedGeneralPurposeRegisters { 62 s.markRegisterUsed(r) 63 } 64 // Now we cannot take free ones for int. 65 _, ok = s.takeFreeRegister(registerTypeGeneralPurpose) 66 require.False(t, ok) 67 // But we still should be able to take float regs. 68 r, ok = s.takeFreeRegister(registerTypeVector) 69 require.True(t, ok) 70 require.True(t, isVectorRegister(r)) 71 // Mark all the float registers used. 72 for _, r := range unreservedVectorRegisters { 73 s.markRegisterUsed(r) 74 } 75 // Now we cannot take free ones for floats. 76 _, ok = s.takeFreeRegister(registerTypeVector) 77 require.False(t, ok) 78 } 79 80 func TestRuntimeValueLocationStack_takeStealTargetFromUsedRegister(t *testing.T) { 81 s := newRuntimeValueLocationStack() 82 intReg := unreservedGeneralPurposeRegisters[0] 83 floatReg := unreservedVectorRegisters[0] 84 intLocation := s.push(intReg, asm.ConditionalRegisterStateUnset) 85 floatLocation := s.push(floatReg, asm.ConditionalRegisterStateUnset) 86 // Take for float. 87 target, ok := s.takeStealTargetFromUsedRegister(registerTypeVector) 88 require.True(t, ok) 89 require.Equal(t, floatLocation, target) 90 // Take for ints. 91 target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose) 92 require.True(t, ok) 93 require.Equal(t, intLocation, target) 94 // Pop float value. 95 popped := s.pop() 96 require.Equal(t, floatLocation, popped) 97 // Now we cannot find the steal target. 98 target, ok = s.takeStealTargetFromUsedRegister(registerTypeVector) 99 require.False(t, ok) 100 require.Nil(t, target) 101 // Pop int value. 102 popped = s.pop() 103 require.Equal(t, intLocation, popped) 104 // Now we cannot find the steal target. 105 target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose) 106 require.False(t, ok) 107 require.Nil(t, target) 108 } 109 110 func TestRuntimeValueLocationStack_setupInitialStack(t *testing.T) { 111 const f32 = wasm.ValueTypeF32 112 tests := []struct { 113 name string 114 sig *wasm.FunctionType 115 expectedSP uint64 116 }{ 117 { 118 name: "no params / no results", 119 sig: &wasm.FunctionType{}, 120 expectedSP: callFrameDataSizeInUint64, 121 }, 122 { 123 name: "no results", 124 sig: &wasm.FunctionType{ 125 Params: []wasm.ValueType{f32, f32}, 126 ParamNumInUint64: 2, 127 }, 128 expectedSP: callFrameDataSizeInUint64 + 2, 129 }, 130 { 131 name: "no params", 132 sig: &wasm.FunctionType{ 133 Results: []wasm.ValueType{f32, f32}, 134 ResultNumInUint64: 2, 135 }, 136 expectedSP: callFrameDataSizeInUint64 + 2, 137 }, 138 { 139 name: "params == results", 140 sig: &wasm.FunctionType{ 141 Params: []wasm.ValueType{f32, f32}, 142 ParamNumInUint64: 2, 143 Results: []wasm.ValueType{f32, f32}, 144 ResultNumInUint64: 2, 145 }, 146 expectedSP: callFrameDataSizeInUint64 + 2, 147 }, 148 { 149 name: "params > results", 150 sig: &wasm.FunctionType{ 151 Params: []wasm.ValueType{f32, f32, f32}, 152 ParamNumInUint64: 3, 153 Results: []wasm.ValueType{f32, f32}, 154 ResultNumInUint64: 2, 155 }, 156 expectedSP: callFrameDataSizeInUint64 + 3, 157 }, 158 { 159 name: "params < results", 160 sig: &wasm.FunctionType{ 161 Params: []wasm.ValueType{f32}, 162 ParamNumInUint64: 1, 163 Results: []wasm.ValueType{f32, f32, f32}, 164 ResultNumInUint64: 3, 165 }, 166 expectedSP: callFrameDataSizeInUint64 + 3, 167 }, 168 } 169 170 for _, tc := range tests { 171 tc := tc 172 t.Run(tc.name, func(t *testing.T) { 173 s := newRuntimeValueLocationStack() 174 s.init(tc.sig) 175 require.Equal(t, tc.expectedSP, s.sp) 176 177 callFrameLocations := s.stack[s.sp-callFrameDataSizeInUint64 : s.sp] 178 for _, loc := range callFrameLocations { 179 require.Equal(t, runtimeValueTypeI64, loc.valueType) 180 } 181 }) 182 } 183 } 184 185 func TestRuntimeValueLocation_pushCallFrame(t *testing.T) { 186 for _, sig := range []*wasm.FunctionType{ 187 {ParamNumInUint64: 0, ResultNumInUint64: 1}, 188 {ParamNumInUint64: 1, ResultNumInUint64: 0}, 189 {ParamNumInUint64: 1, ResultNumInUint64: 1}, 190 {ParamNumInUint64: 0, ResultNumInUint64: 2}, 191 {ParamNumInUint64: 2, ResultNumInUint64: 0}, 192 {ParamNumInUint64: 2, ResultNumInUint64: 3}, 193 } { 194 sig := sig 195 t.Run(sig.String(), func(t *testing.T) { 196 s := newRuntimeValueLocationStack() 197 // pushCallFrame assumes that the parameters are already pushed. 198 for i := 0; i < sig.ParamNumInUint64; i++ { 199 _ = s.pushRuntimeValueLocationOnStack() 200 } 201 202 retAddr, stackBasePointer, fn := s.pushCallFrame(sig) 203 204 expOffset := uint64(callFrameOffset(sig)) 205 require.Equal(t, expOffset, retAddr.stackPointer) 206 require.Equal(t, expOffset+1, stackBasePointer.stackPointer) 207 require.Equal(t, expOffset+2, fn.stackPointer) 208 }) 209 } 210 } 211 212 func Test_usedRegistersMask(t *testing.T) { 213 for _, r := range append(unreservedVectorRegisters, unreservedGeneralPurposeRegisters...) { 214 mask := usedRegistersMask(0) 215 mask.add(r) 216 require.False(t, mask == 0) 217 require.True(t, mask.exist(r)) 218 mask.remove(r) 219 require.True(t, mask == 0) 220 require.False(t, mask.exist(r)) 221 } 222 } 223 224 func TestRuntimeValueLocation_cloneFrom(t *testing.T) { 225 t.Run("sp<cap", func(t *testing.T) { 226 v := runtimeValueLocationStack{sp: 7, stack: make([]runtimeValueLocation, 5, 10)} 227 orig := v.stack 228 v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{ 229 {register: 3}, {register: 2}, {register: 1}, 230 }}) 231 require.Equal(t, uint64(3), v.sp) 232 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters) 233 // Underlying stack shouldn't have changed since sp=3 < cap(v.stack). 234 require.Equal(t, &orig[0], &v.stack[0]) 235 require.Equal(t, v.stack[0].register, asm.Register(3)) 236 require.Equal(t, v.stack[1].register, asm.Register(2)) 237 require.Equal(t, v.stack[2].register, asm.Register(1)) 238 }) 239 t.Run("sp=cap", func(t *testing.T) { 240 v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)} 241 orig := v.stack[:cap(v.stack)] 242 v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{ 243 {register: 3}, {register: 2}, {register: 1}, 244 }}) 245 require.Equal(t, uint64(3), v.sp) 246 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters) 247 // Underlying stack shouldn't have changed since sp=3==cap(v.stack). 248 require.Equal(t, &orig[0], &v.stack[0]) 249 require.Equal(t, v.stack[0].register, asm.Register(3)) 250 require.Equal(t, v.stack[1].register, asm.Register(2)) 251 require.Equal(t, v.stack[2].register, asm.Register(1)) 252 }) 253 t.Run("sp>cap", func(t *testing.T) { 254 v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)} 255 orig := v.stack[:cap(v.stack)] 256 v.cloneFrom(runtimeValueLocationStack{sp: 5, usedRegisters: 0xffff, stack: []runtimeValueLocation{ 257 {register: 5}, {register: 4}, {register: 3}, {register: 2}, {register: 1}, 258 }}) 259 require.Equal(t, uint64(5), v.sp) 260 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters) 261 // Underlying stack should have changed since sp=5>cap(v.stack). 262 require.NotEqual(t, &orig[0], &v.stack[0]) 263 require.Equal(t, v.stack[0].register, asm.Register(5)) 264 require.Equal(t, v.stack[1].register, asm.Register(4)) 265 require.Equal(t, v.stack[2].register, asm.Register(3)) 266 require.Equal(t, v.stack[3].register, asm.Register(2)) 267 require.Equal(t, v.stack[4].register, asm.Register(1)) 268 }) 269 }