wa-lang.org/wazero@v1.0.2/internal/wasm/gofunc_test.go (about) 1 package wasm 2 3 import ( 4 "context" 5 "math" 6 "testing" 7 "unsafe" 8 9 "wa-lang.org/wazero/api" 10 "wa-lang.org/wazero/internal/testing/require" 11 ) 12 13 // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. 14 var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") 15 16 func Test_parseGoFunc(t *testing.T) { 17 tests := []struct { 18 name string 19 input interface{} 20 expectNeedsModule bool 21 expectedType *FunctionType 22 }{ 23 { 24 name: "() -> ()", 25 input: func() {}, 26 expectedType: &FunctionType{}, 27 }, 28 { 29 name: "(ctx) -> ()", 30 input: func(context.Context) {}, 31 expectedType: &FunctionType{}, 32 }, 33 { 34 name: "(ctx, mod) -> ()", 35 input: func(context.Context, api.Module) {}, 36 expectNeedsModule: true, 37 expectedType: &FunctionType{}, 38 }, 39 { 40 name: "all supported params and i32 result", 41 input: func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 42 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 43 }, 44 { 45 name: "all supported params and i32 result - (ctx)", 46 input: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 47 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 48 }, 49 { 50 name: "all supported params and i32 result - (ctx, mod)", 51 input: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 52 expectNeedsModule: true, 53 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 54 }, 55 } 56 for _, tt := range tests { 57 tc := tt 58 59 t.Run(tc.name, func(t *testing.T) { 60 paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input) 61 require.NoError(t, err) 62 _, isModuleFunc := code.GoFunc.(api.GoModuleFunction) 63 require.Equal(t, tc.expectNeedsModule, isModuleFunc) 64 require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes}) 65 }) 66 } 67 } 68 69 func Test_parseGoFunc_Errors(t *testing.T) { 70 tests := []struct { 71 name string 72 input interface{} 73 expectedErr string 74 }{ 75 { 76 name: "module no context", 77 input: func(api.Module) {}, 78 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", 79 }, 80 { 81 name: "not a func", 82 input: struct{}{}, 83 expectedErr: "kind != func: struct", 84 }, 85 { 86 name: "unsupported param", 87 input: func(context.Context, uint32, string) {}, 88 expectedErr: "param[2] is unsupported: string", 89 }, 90 { 91 name: "unsupported result", 92 input: func() string { return "" }, 93 expectedErr: "result[0] is unsupported: string", 94 }, 95 { 96 name: "error result", 97 input: func() error { return nil }, 98 expectedErr: "result[0] is an error, which is unsupported", 99 }, 100 { 101 name: "incorrect order", 102 input: func(api.Module, context.Context) error { return nil }, 103 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", 104 }, 105 { 106 name: "multiple context.Context", 107 input: func(context.Context, uint64, context.Context) error { return nil }, 108 expectedErr: "param[2] is a context.Context, which may be defined only once as param[0]", 109 }, 110 { 111 name: "multiple wasm.Module", 112 input: func(context.Context, api.Module, uint64, api.Module) error { return nil }, 113 expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]", 114 }, 115 } 116 117 for _, tt := range tests { 118 tc := tt 119 120 t.Run(tc.name, func(t *testing.T) { 121 _, _, _, err := parseGoReflectFunc(tc.input) 122 require.EqualError(t, err, tc.expectedErr) 123 }) 124 } 125 } 126 127 // stack simulates the value stack in a way easy to be tested. 128 type stack struct { 129 vals []uint64 130 } 131 132 func (s *stack) pop() (result uint64) { 133 stackTopIndex := len(s.vals) - 1 134 result = s.vals[stackTopIndex] 135 s.vals = s.vals[:stackTopIndex] 136 return 137 } 138 139 func TestPopValues(t *testing.T) { 140 stackVals := []uint64{1, 2, 3, 4, 5, 6, 7} 141 tests := []struct { 142 name string 143 count int 144 expected []uint64 145 }{ 146 { 147 name: "pop zero doesn't allocate a slice ", 148 }, 149 { 150 name: "pop 1", 151 count: 1, 152 expected: []uint64{7}, 153 }, 154 { 155 name: "pop 2", 156 count: 2, 157 expected: []uint64{6, 7}, 158 }, 159 { 160 name: "pop 3", 161 count: 3, 162 expected: []uint64{5, 6, 7}, 163 }, 164 } 165 166 for _, tt := range tests { 167 tc := tt 168 169 t.Run(tc.name, func(t *testing.T) { 170 vals := PopValues(tc.count, (&stack{stackVals}).pop) 171 require.Equal(t, tc.expected, vals) 172 }) 173 } 174 } 175 176 func Test_callGoFunc(t *testing.T) { 177 tPtr := uintptr(unsafe.Pointer(t)) 178 callCtx := &CallContext{} 179 180 tests := []struct { 181 name string 182 input interface{} 183 inputParams, expectedResults []uint64 184 }{ 185 { 186 name: "() -> ()", 187 input: func() {}, 188 }, 189 { 190 name: "(ctx) -> ()", 191 input: func(ctx context.Context) { 192 require.Equal(t, testCtx, ctx) 193 }, 194 }, 195 { 196 name: "(ctx, mod) -> ()", 197 input: func(ctx context.Context, m api.Module) { 198 require.Equal(t, testCtx, ctx) 199 require.Equal(t, callCtx, m) 200 }, 201 }, 202 { 203 name: "all supported params and i32 result", 204 input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 205 require.Equal(t, tPtr, v) 206 require.Equal(t, uint32(math.MaxUint32), w) 207 require.Equal(t, uint64(math.MaxUint64), x) 208 require.Equal(t, float32(math.MaxFloat32), y) 209 require.Equal(t, math.MaxFloat64, z) 210 return 100 211 }, 212 inputParams: []uint64{ 213 api.EncodeExternref(tPtr), 214 math.MaxUint32, 215 math.MaxUint64, 216 api.EncodeF32(math.MaxFloat32), 217 api.EncodeF64(math.MaxFloat64), 218 }, 219 expectedResults: []uint64{100}, 220 }, 221 { 222 name: "all supported params and i32 result - (ctx)", 223 input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 224 require.Equal(t, testCtx, ctx) 225 require.Equal(t, tPtr, v) 226 require.Equal(t, uint32(math.MaxUint32), w) 227 require.Equal(t, uint64(math.MaxUint64), x) 228 require.Equal(t, float32(math.MaxFloat32), y) 229 require.Equal(t, math.MaxFloat64, z) 230 return 100 231 }, 232 inputParams: []uint64{ 233 api.EncodeExternref(tPtr), 234 math.MaxUint32, 235 math.MaxUint64, 236 api.EncodeF32(math.MaxFloat32), 237 api.EncodeF64(math.MaxFloat64), 238 }, 239 expectedResults: []uint64{100}, 240 }, 241 { 242 name: "all supported params and i32 result - (ctx, mod)", 243 input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 244 require.Equal(t, testCtx, ctx) 245 require.Equal(t, callCtx, m) 246 require.Equal(t, tPtr, v) 247 require.Equal(t, uint32(math.MaxUint32), w) 248 require.Equal(t, uint64(math.MaxUint64), x) 249 require.Equal(t, float32(math.MaxFloat32), y) 250 require.Equal(t, math.MaxFloat64, z) 251 return 100 252 }, 253 inputParams: []uint64{ 254 api.EncodeExternref(tPtr), 255 math.MaxUint32, 256 math.MaxUint64, 257 api.EncodeF32(math.MaxFloat32), 258 api.EncodeF64(math.MaxFloat64), 259 }, 260 expectedResults: []uint64{100}, 261 }, 262 } 263 for _, tt := range tests { 264 tc := tt 265 266 t.Run(tc.name, func(t *testing.T) { 267 _, _, code, err := parseGoReflectFunc(tc.input) 268 require.NoError(t, err) 269 270 resultLen := len(tc.expectedResults) 271 stackLen := len(tc.inputParams) 272 if resultLen > stackLen { 273 stackLen = resultLen 274 } 275 stack := make([]uint64, stackLen) 276 copy(stack, tc.inputParams) 277 278 switch code.GoFunc.(type) { 279 case api.GoFunction: 280 code.GoFunc.(api.GoFunction).Call(testCtx, stack) 281 case api.GoModuleFunction: 282 code.GoFunc.(api.GoModuleFunction).Call(testCtx, callCtx, stack) 283 default: 284 t.Fatal("unexpected type.") 285 } 286 287 var results []uint64 288 if resultLen > 0 { 289 results = stack[:resultLen] 290 } 291 require.Equal(t, tc.expectedResults, results) 292 }) 293 } 294 }