wa-lang.org/wazero@v1.0.2/internal/wasm/gofunc_test.go (about)

     1  package wasm
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"testing"
     7  	"unsafe"
     8  
     9  	"wa-lang.org/wazero/api"
    10  	"wa-lang.org/wazero/internal/testing/require"
    11  )
    12  
    13  // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors.
    14  var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary")
    15  
    16  func Test_parseGoFunc(t *testing.T) {
    17  	tests := []struct {
    18  		name              string
    19  		input             interface{}
    20  		expectNeedsModule bool
    21  		expectedType      *FunctionType
    22  	}{
    23  		{
    24  			name:         "() -> ()",
    25  			input:        func() {},
    26  			expectedType: &FunctionType{},
    27  		},
    28  		{
    29  			name:         "(ctx) -> ()",
    30  			input:        func(context.Context) {},
    31  			expectedType: &FunctionType{},
    32  		},
    33  		{
    34  			name:              "(ctx, mod) -> ()",
    35  			input:             func(context.Context, api.Module) {},
    36  			expectNeedsModule: true,
    37  			expectedType:      &FunctionType{},
    38  		},
    39  		{
    40  			name:         "all supported params and i32 result",
    41  			input:        func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    42  			expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    43  		},
    44  		{
    45  			name:         "all supported params and i32 result - (ctx)",
    46  			input:        func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    47  			expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    48  		},
    49  		{
    50  			name:              "all supported params and i32 result - (ctx, mod)",
    51  			input:             func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
    52  			expectNeedsModule: true,
    53  			expectedType:      &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
    54  		},
    55  	}
    56  	for _, tt := range tests {
    57  		tc := tt
    58  
    59  		t.Run(tc.name, func(t *testing.T) {
    60  			paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input)
    61  			require.NoError(t, err)
    62  			_, isModuleFunc := code.GoFunc.(api.GoModuleFunction)
    63  			require.Equal(t, tc.expectNeedsModule, isModuleFunc)
    64  			require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes})
    65  		})
    66  	}
    67  }
    68  
    69  func Test_parseGoFunc_Errors(t *testing.T) {
    70  	tests := []struct {
    71  		name        string
    72  		input       interface{}
    73  		expectedErr string
    74  	}{
    75  		{
    76  			name:        "module no context",
    77  			input:       func(api.Module) {},
    78  			expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
    79  		},
    80  		{
    81  			name:        "not a func",
    82  			input:       struct{}{},
    83  			expectedErr: "kind != func: struct",
    84  		},
    85  		{
    86  			name:        "unsupported param",
    87  			input:       func(context.Context, uint32, string) {},
    88  			expectedErr: "param[2] is unsupported: string",
    89  		},
    90  		{
    91  			name:        "unsupported result",
    92  			input:       func() string { return "" },
    93  			expectedErr: "result[0] is unsupported: string",
    94  		},
    95  		{
    96  			name:        "error result",
    97  			input:       func() error { return nil },
    98  			expectedErr: "result[0] is an error, which is unsupported",
    99  		},
   100  		{
   101  			name:        "incorrect order",
   102  			input:       func(api.Module, context.Context) error { return nil },
   103  			expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
   104  		},
   105  		{
   106  			name:        "multiple context.Context",
   107  			input:       func(context.Context, uint64, context.Context) error { return nil },
   108  			expectedErr: "param[2] is a context.Context, which may be defined only once as param[0]",
   109  		},
   110  		{
   111  			name:        "multiple wasm.Module",
   112  			input:       func(context.Context, api.Module, uint64, api.Module) error { return nil },
   113  			expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]",
   114  		},
   115  	}
   116  
   117  	for _, tt := range tests {
   118  		tc := tt
   119  
   120  		t.Run(tc.name, func(t *testing.T) {
   121  			_, _, _, err := parseGoReflectFunc(tc.input)
   122  			require.EqualError(t, err, tc.expectedErr)
   123  		})
   124  	}
   125  }
   126  
   127  // stack simulates the value stack in a way easy to be tested.
   128  type stack struct {
   129  	vals []uint64
   130  }
   131  
   132  func (s *stack) pop() (result uint64) {
   133  	stackTopIndex := len(s.vals) - 1
   134  	result = s.vals[stackTopIndex]
   135  	s.vals = s.vals[:stackTopIndex]
   136  	return
   137  }
   138  
   139  func TestPopValues(t *testing.T) {
   140  	stackVals := []uint64{1, 2, 3, 4, 5, 6, 7}
   141  	tests := []struct {
   142  		name     string
   143  		count    int
   144  		expected []uint64
   145  	}{
   146  		{
   147  			name: "pop zero doesn't allocate a slice ",
   148  		},
   149  		{
   150  			name:     "pop 1",
   151  			count:    1,
   152  			expected: []uint64{7},
   153  		},
   154  		{
   155  			name:     "pop 2",
   156  			count:    2,
   157  			expected: []uint64{6, 7},
   158  		},
   159  		{
   160  			name:     "pop 3",
   161  			count:    3,
   162  			expected: []uint64{5, 6, 7},
   163  		},
   164  	}
   165  
   166  	for _, tt := range tests {
   167  		tc := tt
   168  
   169  		t.Run(tc.name, func(t *testing.T) {
   170  			vals := PopValues(tc.count, (&stack{stackVals}).pop)
   171  			require.Equal(t, tc.expected, vals)
   172  		})
   173  	}
   174  }
   175  
   176  func Test_callGoFunc(t *testing.T) {
   177  	tPtr := uintptr(unsafe.Pointer(t))
   178  	callCtx := &CallContext{}
   179  
   180  	tests := []struct {
   181  		name                         string
   182  		input                        interface{}
   183  		inputParams, expectedResults []uint64
   184  	}{
   185  		{
   186  			name:  "() -> ()",
   187  			input: func() {},
   188  		},
   189  		{
   190  			name: "(ctx) -> ()",
   191  			input: func(ctx context.Context) {
   192  				require.Equal(t, testCtx, ctx)
   193  			},
   194  		},
   195  		{
   196  			name: "(ctx, mod) -> ()",
   197  			input: func(ctx context.Context, m api.Module) {
   198  				require.Equal(t, testCtx, ctx)
   199  				require.Equal(t, callCtx, m)
   200  			},
   201  		},
   202  		{
   203  			name: "all supported params and i32 result",
   204  			input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   205  				require.Equal(t, tPtr, v)
   206  				require.Equal(t, uint32(math.MaxUint32), w)
   207  				require.Equal(t, uint64(math.MaxUint64), x)
   208  				require.Equal(t, float32(math.MaxFloat32), y)
   209  				require.Equal(t, math.MaxFloat64, z)
   210  				return 100
   211  			},
   212  			inputParams: []uint64{
   213  				api.EncodeExternref(tPtr),
   214  				math.MaxUint32,
   215  				math.MaxUint64,
   216  				api.EncodeF32(math.MaxFloat32),
   217  				api.EncodeF64(math.MaxFloat64),
   218  			},
   219  			expectedResults: []uint64{100},
   220  		},
   221  		{
   222  			name: "all supported params and i32 result - (ctx)",
   223  			input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   224  				require.Equal(t, testCtx, ctx)
   225  				require.Equal(t, tPtr, v)
   226  				require.Equal(t, uint32(math.MaxUint32), w)
   227  				require.Equal(t, uint64(math.MaxUint64), x)
   228  				require.Equal(t, float32(math.MaxFloat32), y)
   229  				require.Equal(t, math.MaxFloat64, z)
   230  				return 100
   231  			},
   232  			inputParams: []uint64{
   233  				api.EncodeExternref(tPtr),
   234  				math.MaxUint32,
   235  				math.MaxUint64,
   236  				api.EncodeF32(math.MaxFloat32),
   237  				api.EncodeF64(math.MaxFloat64),
   238  			},
   239  			expectedResults: []uint64{100},
   240  		},
   241  		{
   242  			name: "all supported params and i32 result - (ctx, mod)",
   243  			input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
   244  				require.Equal(t, testCtx, ctx)
   245  				require.Equal(t, callCtx, m)
   246  				require.Equal(t, tPtr, v)
   247  				require.Equal(t, uint32(math.MaxUint32), w)
   248  				require.Equal(t, uint64(math.MaxUint64), x)
   249  				require.Equal(t, float32(math.MaxFloat32), y)
   250  				require.Equal(t, math.MaxFloat64, z)
   251  				return 100
   252  			},
   253  			inputParams: []uint64{
   254  				api.EncodeExternref(tPtr),
   255  				math.MaxUint32,
   256  				math.MaxUint64,
   257  				api.EncodeF32(math.MaxFloat32),
   258  				api.EncodeF64(math.MaxFloat64),
   259  			},
   260  			expectedResults: []uint64{100},
   261  		},
   262  	}
   263  	for _, tt := range tests {
   264  		tc := tt
   265  
   266  		t.Run(tc.name, func(t *testing.T) {
   267  			_, _, code, err := parseGoReflectFunc(tc.input)
   268  			require.NoError(t, err)
   269  
   270  			resultLen := len(tc.expectedResults)
   271  			stackLen := len(tc.inputParams)
   272  			if resultLen > stackLen {
   273  				stackLen = resultLen
   274  			}
   275  			stack := make([]uint64, stackLen)
   276  			copy(stack, tc.inputParams)
   277  
   278  			switch code.GoFunc.(type) {
   279  			case api.GoFunction:
   280  				code.GoFunc.(api.GoFunction).Call(testCtx, stack)
   281  			case api.GoModuleFunction:
   282  				code.GoFunc.(api.GoModuleFunction).Call(testCtx, callCtx, stack)
   283  			default:
   284  				t.Fatal("unexpected type.")
   285  			}
   286  
   287  			var results []uint64
   288  			if resultLen > 0 {
   289  				results = stack[:resultLen]
   290  			}
   291  			require.Equal(t, tc.expectedResults, results)
   292  		})
   293  	}
   294  }