github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/wasm/gofunc_test.go (about)

     1  package wasm
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"testing"
     7  	"unsafe"
     8  
     9  	"github.com/tetratelabs/wazero/api"
    10  	"github.com/tetratelabs/wazero/internal/testing/require"
    11  )
    12  
    13  // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors.
    14  var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary")
    15  
    16  func Test_parseGoFunc(t *testing.T) {
    17  	tests := []struct {
    18  		name              string
    19  		input             interface{}
    20  		expectNeedsModule bool
    21  		expectedType      *FunctionType
    22  	}{
    23  		{
    24  			name:         "() -> ()",
    25  			input:        func() {},
    26  			expectedType: &FunctionType{},
    27  		},
    28  		{
    29  			name:         "(ctx) -> ()",
    30  			input:        func(context.Context) {},
    31  			expectedType: &FunctionType{},
    32  		},
    33  		{
    34  			name:              "(ctx, mod) -> ()",
    35  			input:             func(context.Context, api.Module) {},
    36  			expectNeedsModule: true,
    37  			expectedType:      &FunctionType{},
    38  		},
    39  		{
    40  			name:         "all supported params and i32 result",
    41  			input:        func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    42  			expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    43  		},
    44  		{
    45  			name:         "all supported params and i32 result - (ctx)",
    46  			input:        func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    47  			expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    48  		},
    49  		{
    50  			name:              "all supported params and i32 result - (ctx, mod)",
    51  			input:             func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    52  			expectNeedsModule: true,
    53  			expectedType:      &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    54  		},
    55  	}
    56  	for _, tt := range tests {
    57  		tc := tt
    58  
    59  		t.Run(tc.name, func(t *testing.T) {
    60  			paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input)
    61  			require.NoError(t, err)
    62  			_, isModuleFunc := code.GoFunc.(api.GoModuleFunction)
    63  			require.Equal(t, tc.expectNeedsModule, isModuleFunc)
    64  			require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes})
    65  		})
    66  	}
    67  }
    68  
    69  func Test_parseGoFunc_Errors(t *testing.T) {
    70  	tests := []struct {
    71  		name        string
    72  		input       interface{}
    73  		expectedErr string
    74  	}{
    75  		{
    76  			name:        "module no context",
    77  			input:       func(api.Module) {},
    78  			expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
    79  		},
    80  		{
    81  			name:        "not a func",
    82  			input:       struct{}{},
    83  			expectedErr: "kind != func: struct",
    84  		},
    85  		{
    86  			name:        "unsupported param",
    87  			input:       func(context.Context, uint32, string) {},
    88  			expectedErr: "param[2] is unsupported: string",
    89  		},
    90  		{
    91  			name:        "unsupported result",
    92  			input:       func() string { return "" },
    93  			expectedErr: "result[0] is unsupported: string",
    94  		},
    95  		{
    96  			name:        "error result",
    97  			input:       func() error { return nil },
    98  			expectedErr: "result[0] is an error, which is unsupported",
    99  		},
   100  		{
   101  			name:        "incorrect order",
   102  			input:       func(api.Module, context.Context) error { return nil },
   103  			expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
   104  		},
   105  		{
   106  			name:        "multiple context.Context",
   107  			input:       func(context.Context, uint64, context.Context) error { return nil },
   108  			expectedErr: "param[2] is a context.Context, which may be defined only once as param[0]",
   109  		},
   110  		{
   111  			name:        "multiple wasm.Module",
   112  			input:       func(context.Context, api.Module, uint64, api.Module) error { return nil },
   113  			expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]",
   114  		},
   115  	}
   116  
   117  	for _, tt := range tests {
   118  		tc := tt
   119  
   120  		t.Run(tc.name, func(t *testing.T) {
   121  			_, _, _, err := parseGoReflectFunc(tc.input)
   122  			require.EqualError(t, err, tc.expectedErr)
   123  		})
   124  	}
   125  }
   126  
   127  func Test_callGoFunc(t *testing.T) {
   128  	tPtr := uintptr(unsafe.Pointer(t))
   129  	inst := &ModuleInstance{}
   130  
   131  	tests := []struct {
   132  		name                         string
   133  		input                        interface{}
   134  		inputParams, expectedResults []uint64
   135  	}{
   136  		{
   137  			name:  "() -> ()",
   138  			input: func() {},
   139  		},
   140  		{
   141  			name: "(ctx) -> ()",
   142  			input: func(ctx context.Context) {
   143  				require.Equal(t, testCtx, ctx)
   144  			},
   145  		},
   146  		{
   147  			name: "(ctx, mod) -> ()",
   148  			input: func(ctx context.Context, m api.Module) {
   149  				require.Equal(t, testCtx, ctx)
   150  				require.Equal(t, inst, m)
   151  			},
   152  		},
   153  		{
   154  			name: "all supported params and i32 result",
   155  			input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   156  				require.Equal(t, tPtr, v)
   157  				require.Equal(t, uint32(math.MaxUint32), w)
   158  				require.Equal(t, uint64(math.MaxUint64), x)
   159  				require.Equal(t, float32(math.MaxFloat32), y)
   160  				require.Equal(t, math.MaxFloat64, z)
   161  				return 100
   162  			},
   163  			inputParams: []uint64{
   164  				api.EncodeExternref(tPtr),
   165  				math.MaxUint32,
   166  				math.MaxUint64,
   167  				api.EncodeF32(math.MaxFloat32),
   168  				api.EncodeF64(math.MaxFloat64),
   169  			},
   170  			expectedResults: []uint64{100},
   171  		},
   172  		{
   173  			name: "all supported params and i32 result - (ctx)",
   174  			input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   175  				require.Equal(t, testCtx, ctx)
   176  				require.Equal(t, tPtr, v)
   177  				require.Equal(t, uint32(math.MaxUint32), w)
   178  				require.Equal(t, uint64(math.MaxUint64), x)
   179  				require.Equal(t, float32(math.MaxFloat32), y)
   180  				require.Equal(t, math.MaxFloat64, z)
   181  				return 100
   182  			},
   183  			inputParams: []uint64{
   184  				api.EncodeExternref(tPtr),
   185  				math.MaxUint32,
   186  				math.MaxUint64,
   187  				api.EncodeF32(math.MaxFloat32),
   188  				api.EncodeF64(math.MaxFloat64),
   189  			},
   190  			expectedResults: []uint64{100},
   191  		},
   192  		{
   193  			name: "all supported params and i32 result - (ctx, mod)",
   194  			input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   195  				require.Equal(t, testCtx, ctx)
   196  				require.Equal(t, inst, m)
   197  				require.Equal(t, tPtr, v)
   198  				require.Equal(t, uint32(math.MaxUint32), w)
   199  				require.Equal(t, uint64(math.MaxUint64), x)
   200  				require.Equal(t, float32(math.MaxFloat32), y)
   201  				require.Equal(t, math.MaxFloat64, z)
   202  				return 100
   203  			},
   204  			inputParams: []uint64{
   205  				api.EncodeExternref(tPtr),
   206  				math.MaxUint32,
   207  				math.MaxUint64,
   208  				api.EncodeF32(math.MaxFloat32),
   209  				api.EncodeF64(math.MaxFloat64),
   210  			},
   211  			expectedResults: []uint64{100},
   212  		},
   213  	}
   214  	for _, tt := range tests {
   215  		tc := tt
   216  
   217  		t.Run(tc.name, func(t *testing.T) {
   218  			_, _, code, err := parseGoReflectFunc(tc.input)
   219  			require.NoError(t, err)
   220  
   221  			resultLen := len(tc.expectedResults)
   222  			stackLen := len(tc.inputParams)
   223  			if resultLen > stackLen {
   224  				stackLen = resultLen
   225  			}
   226  			stack := make([]uint64, stackLen)
   227  			copy(stack, tc.inputParams)
   228  
   229  			switch code.GoFunc.(type) {
   230  			case api.GoFunction:
   231  				code.GoFunc.(api.GoFunction).Call(testCtx, stack)
   232  			case api.GoModuleFunction:
   233  				code.GoFunc.(api.GoModuleFunction).Call(testCtx, inst, stack)
   234  			default:
   235  				t.Fatal("unexpected type.")
   236  			}
   237  
   238  			var results []uint64
   239  			if resultLen > 0 {
   240  				results = stack[:resultLen]
   241  			}
   242  			require.Equal(t, tc.expectedResults, results)
   243  		})
   244  	}
   245  }