github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/hostmodule_test.go (about)

     1  package wazevo
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"unsafe"
     7  
     8  	"github.com/bananabytelabs/wazero/api"
     9  	"github.com/bananabytelabs/wazero/experimental"
    10  	"github.com/bananabytelabs/wazero/internal/testing/require"
    11  	"github.com/bananabytelabs/wazero/internal/wasm"
    12  )
    13  
    14  func Test_writeIface_readIface(t *testing.T) {
    15  	buf := make([]byte, 100)
    16  
    17  	var called bool
    18  	var goFn api.GoFunction = api.GoFunc(func(context.Context, []uint64) {
    19  		called = true
    20  	})
    21  	writeIface(goFn, buf)
    22  	got := readIface(buf).(api.GoFunction)
    23  	got.Call(context.Background(), nil)
    24  	require.True(t, called)
    25  }
    26  
    27  func Test_buildHostModuleOpaque(t *testing.T) {
    28  	for _, tc := range []struct {
    29  		name      string
    30  		m         *wasm.Module
    31  		listeners []experimental.FunctionListener
    32  	}{
    33  		{
    34  			name: "no listeners",
    35  			m: &wasm.Module{
    36  				CodeSection: []wasm.Code{
    37  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    38  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    39  				},
    40  			},
    41  		},
    42  		{
    43  			name: "listeners",
    44  			m: &wasm.Module{
    45  				CodeSection: []wasm.Code{
    46  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    47  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    48  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    49  					{GoFunc: api.GoFunc(func(context.Context, []uint64) {})},
    50  				},
    51  			},
    52  			listeners: make([]experimental.FunctionListener, 50),
    53  		},
    54  	} {
    55  		tc := tc
    56  		t.Run(tc.name, func(t *testing.T) {
    57  			got := buildHostModuleOpaque(tc.m, tc.listeners)
    58  			opaque := uintptr(unsafe.Pointer(&got[0]))
    59  			require.Equal(t, tc.m, hostModuleFromOpaque(opaque))
    60  			if len(tc.listeners) > 0 {
    61  				require.Equal(t, tc.listeners, hostModuleListenersSliceFromOpaque(opaque))
    62  			}
    63  			for i, c := range tc.m.CodeSection {
    64  				require.Equal(t, c.GoFunc, hostModuleGoFuncFromOpaque[api.GoFunction](i, opaque))
    65  			}
    66  		})
    67  	}
    68  }