github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/wasm/gofunc_test.go (about) 1 package wasm 2 3 import ( 4 "context" 5 "math" 6 "testing" 7 "unsafe" 8 9 "github.com/bananabytelabs/wazero/api" 10 "github.com/bananabytelabs/wazero/internal/testing/require" 11 ) 12 13 // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. 14 var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") 15 16 func Test_parseGoFunc(t *testing.T) { 17 tests := []struct { 18 name string 19 input interface{} 20 expectNeedsModule bool 21 expectedType *FunctionType 22 }{ 23 { 24 name: "() -> ()", 25 input: func() {}, 26 expectedType: &FunctionType{}, 27 }, 28 { 29 name: "(ctx) -> ()", 30 input: func(context.Context) {}, 31 expectedType: &FunctionType{}, 32 }, 33 { 34 name: "(ctx, mod) -> ()", 35 input: func(context.Context, api.Module) {}, 36 expectNeedsModule: true, 37 expectedType: &FunctionType{}, 38 }, 39 { 40 name: "all supported params and i32 result", 41 input: func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 42 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 43 }, 44 { 45 name: "all supported params and i32 result - (ctx)", 46 input: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 47 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 48 }, 49 { 50 name: "all supported params and i32 result - (ctx, mod)", 51 input: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, 52 expectNeedsModule: true, 53 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, 54 }, 55 } 56 for _, tt := range tests { 57 tc := tt 58 59 t.Run(tc.name, func(t *testing.T) { 60 paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input) 61 require.NoError(t, err) 62 _, isModuleFunc := code.GoFunc.(api.GoModuleFunction) 63 require.Equal(t, tc.expectNeedsModule, isModuleFunc) 64 require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes}) 65 }) 66 } 67 } 68 69 func Test_parseGoFunc_Errors(t *testing.T) { 70 tests := []struct { 71 name string 72 input interface{} 73 expectedErr string 74 }{ 75 { 76 name: "module no context", 77 input: func(api.Module) {}, 78 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", 79 }, 80 { 81 name: "not a func", 82 input: struct{}{}, 83 expectedErr: "kind != func: struct", 84 }, 85 { 86 name: "unsupported param", 87 input: func(context.Context, uint32, string) {}, 88 expectedErr: "param[2] is unsupported: string", 89 }, 90 { 91 name: "unsupported result", 92 input: func() string { return "" }, 93 expectedErr: "result[0] is unsupported: string", 94 }, 95 { 96 name: "error result", 97 input: func() error { return nil }, 98 expectedErr: "result[0] is an error, which is unsupported", 99 }, 100 { 101 name: "incorrect order", 102 input: func(api.Module, context.Context) error { return nil }, 103 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", 104 }, 105 { 106 name: "multiple context.Context", 107 input: func(context.Context, uint64, context.Context) error { return nil }, 108 expectedErr: "param[2] is a context.Context, which may be defined only once as param[0]", 109 }, 110 { 111 name: "multiple wasm.Module", 112 input: func(context.Context, api.Module, uint64, api.Module) error { return nil }, 113 expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]", 114 }, 115 } 116 117 for _, tt := range tests { 118 tc := tt 119 120 t.Run(tc.name, func(t *testing.T) { 121 _, _, _, err := parseGoReflectFunc(tc.input) 122 require.EqualError(t, err, tc.expectedErr) 123 }) 124 } 125 } 126 127 func Test_callGoFunc(t *testing.T) { 128 tPtr := uintptr(unsafe.Pointer(t)) 129 inst := &ModuleInstance{} 130 131 tests := []struct { 132 name string 133 input interface{} 134 inputParams, expectedResults []uint64 135 }{ 136 { 137 name: "() -> ()", 138 input: func() {}, 139 }, 140 { 141 name: "(ctx) -> ()", 142 input: func(ctx context.Context) { 143 require.Equal(t, testCtx, ctx) 144 }, 145 }, 146 { 147 name: "(ctx, mod) -> ()", 148 input: func(ctx context.Context, m api.Module) { 149 require.Equal(t, testCtx, ctx) 150 require.Equal(t, inst, m) 151 }, 152 }, 153 { 154 name: "all supported params and i32 result", 155 input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 156 require.Equal(t, tPtr, v) 157 require.Equal(t, uint32(math.MaxUint32), w) 158 require.Equal(t, uint64(math.MaxUint64), x) 159 require.Equal(t, float32(math.MaxFloat32), y) 160 require.Equal(t, math.MaxFloat64, z) 161 return 100 162 }, 163 inputParams: []uint64{ 164 api.EncodeExternref(tPtr), 165 math.MaxUint32, 166 math.MaxUint64, 167 api.EncodeF32(math.MaxFloat32), 168 api.EncodeF64(math.MaxFloat64), 169 }, 170 expectedResults: []uint64{100}, 171 }, 172 { 173 name: "all supported params and i32 result - (ctx)", 174 input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 175 require.Equal(t, testCtx, ctx) 176 require.Equal(t, tPtr, v) 177 require.Equal(t, uint32(math.MaxUint32), w) 178 require.Equal(t, uint64(math.MaxUint64), x) 179 require.Equal(t, float32(math.MaxFloat32), y) 180 require.Equal(t, math.MaxFloat64, z) 181 return 100 182 }, 183 inputParams: []uint64{ 184 api.EncodeExternref(tPtr), 185 math.MaxUint32, 186 math.MaxUint64, 187 api.EncodeF32(math.MaxFloat32), 188 api.EncodeF64(math.MaxFloat64), 189 }, 190 expectedResults: []uint64{100}, 191 }, 192 { 193 name: "all supported params and i32 result - (ctx, mod)", 194 input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { 195 require.Equal(t, testCtx, ctx) 196 require.Equal(t, inst, m) 197 require.Equal(t, tPtr, v) 198 require.Equal(t, uint32(math.MaxUint32), w) 199 require.Equal(t, uint64(math.MaxUint64), x) 200 require.Equal(t, float32(math.MaxFloat32), y) 201 require.Equal(t, math.MaxFloat64, z) 202 return 100 203 }, 204 inputParams: []uint64{ 205 api.EncodeExternref(tPtr), 206 math.MaxUint32, 207 math.MaxUint64, 208 api.EncodeF32(math.MaxFloat32), 209 api.EncodeF64(math.MaxFloat64), 210 }, 211 expectedResults: []uint64{100}, 212 }, 213 } 214 for _, tt := range tests { 215 tc := tt 216 217 t.Run(tc.name, func(t *testing.T) { 218 _, _, code, err := parseGoReflectFunc(tc.input) 219 require.NoError(t, err) 220 221 resultLen := len(tc.expectedResults) 222 stackLen := len(tc.inputParams) 223 if resultLen > stackLen { 224 stackLen = resultLen 225 } 226 stack := make([]uint64, stackLen) 227 copy(stack, tc.inputParams) 228 229 switch code.GoFunc.(type) { 230 case api.GoFunction: 231 code.GoFunc.(api.GoFunction).Call(testCtx, stack) 232 case api.GoModuleFunction: 233 code.GoFunc.(api.GoModuleFunction).Call(testCtx, inst, stack) 234 default: 235 t.Fatal("unexpected type.") 236 } 237 238 var results []uint64 239 if resultLen > 0 { 240 results = stack[:resultLen] 241 } 242 require.Equal(t, tc.expectedResults, results) 243 }) 244 } 245 }