github.com/tetratelabs/wazero@v1.2.1/experimental/listener_test.go (about)

     1  package experimental_test
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"testing"
     7  
     8  	"github.com/tetratelabs/wazero"
     9  	"github.com/tetratelabs/wazero/api"
    10  	"github.com/tetratelabs/wazero/experimental"
    11  	"github.com/tetratelabs/wazero/experimental/wazerotest"
    12  	"github.com/tetratelabs/wazero/internal/testing/binaryencoding"
    13  	"github.com/tetratelabs/wazero/internal/testing/require"
    14  	"github.com/tetratelabs/wazero/internal/wasm"
    15  )
    16  
    17  // compile-time check to ensure recorder implements FunctionListenerFactory
    18  var _ experimental.FunctionListenerFactory = &recorder{}
    19  
    20  type recorder struct {
    21  	m           map[string]struct{}
    22  	beforeNames []string
    23  	afterNames  []string
    24  	abortNames  []string
    25  }
    26  
    27  func (r *recorder) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64, _ experimental.StackIterator) {
    28  	r.beforeNames = append(r.beforeNames, def.DebugName())
    29  }
    30  
    31  func (r *recorder) After(_ context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64) {
    32  	r.afterNames = append(r.afterNames, def.DebugName())
    33  }
    34  
    35  func (r *recorder) Abort(_ context.Context, _ api.Module, def api.FunctionDefinition, _ error) {
    36  	r.abortNames = append(r.abortNames, def.DebugName())
    37  }
    38  
    39  func (r *recorder) NewFunctionListener(definition api.FunctionDefinition) experimental.FunctionListener {
    40  	r.m[definition.Name()] = struct{}{}
    41  	return r
    42  }
    43  
    44  func TestFunctionListenerFactory(t *testing.T) {
    45  	// Set context to one that has an experimental listener
    46  	factory := &recorder{m: map[string]struct{}{}}
    47  	ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, factory)
    48  
    49  	// Define a module with two functions
    50  	bin := binaryencoding.EncodeModule(&wasm.Module{
    51  		TypeSection:     []wasm.FunctionType{{}},
    52  		ImportSection:   []wasm.Import{{Module: "host"}},
    53  		FunctionSection: []wasm.Index{0, 0},
    54  		CodeSection: []wasm.Code{
    55  			// fn1
    56  			{Body: []byte{
    57  				// call fn2 twice
    58  				wasm.OpcodeCall, 2,
    59  				wasm.OpcodeCall, 2,
    60  				wasm.OpcodeEnd,
    61  			}},
    62  			// fn2
    63  			{Body: []byte{wasm.OpcodeEnd}},
    64  		},
    65  		ExportSection: []wasm.Export{{Name: "fn1", Type: wasm.ExternTypeFunc, Index: 1}},
    66  		NameSection: &wasm.NameSection{
    67  			ModuleName: "test",
    68  			FunctionNames: wasm.NameMap{
    69  				{Index: 0, Name: "import"}, // should skip for building listeners.
    70  				{Index: 1, Name: "fn1"},
    71  				{Index: 2, Name: "fn2"},
    72  			},
    73  		},
    74  	})
    75  
    76  	r := wazero.NewRuntime(ctx)
    77  	defer r.Close(ctx) // This closes everything this Runtime created.
    78  
    79  	_, err := r.NewHostModuleBuilder("host").NewFunctionBuilder().WithFunc(func() {}).Export("").Instantiate(ctx)
    80  	require.NoError(t, err)
    81  
    82  	// Ensure the imported function was converted to a listener.
    83  	require.Equal(t, map[string]struct{}{"": {}}, factory.m)
    84  
    85  	compiled, err := r.CompileModule(ctx, bin)
    86  	require.NoError(t, err)
    87  
    88  	// Ensure each function was converted to a listener eagerly
    89  	require.Equal(t, map[string]struct{}{
    90  		"":    {},
    91  		"fn1": {},
    92  		"fn2": {},
    93  	}, factory.m)
    94  
    95  	// Ensures that FunctionListener is a compile-time option, so passing
    96  	// context.Background here is ok to use listeners at runtime.
    97  	m, err := r.InstantiateModule(context.Background(), compiled, wazero.NewModuleConfig())
    98  	require.NoError(t, err)
    99  
   100  	fn1 := m.ExportedFunction("fn1")
   101  	require.NotNil(t, fn1)
   102  
   103  	_, err = fn1.Call(context.Background())
   104  	require.NoError(t, err)
   105  
   106  	require.Equal(t, []string{"test.fn1", "test.fn2", "test.fn2"}, factory.beforeNames)
   107  	require.Equal(t, []string{"test.fn2", "test.fn2", "test.fn1"}, factory.afterNames) // after is in the reverse order.
   108  }
   109  
   110  func TestMultiFunctionListenerFactory(t *testing.T) {
   111  	module := wazerotest.NewModule(nil,
   112  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   113  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   114  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   115  	)
   116  
   117  	stack := []experimental.StackFrame{
   118  		{Function: module.Function(0), Params: []uint64{1}},
   119  		{Function: module.Function(1), Params: []uint64{2}},
   120  		{Function: module.Function(2), Params: []uint64{3}},
   121  	}
   122  
   123  	n := 0
   124  	f := func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) {
   125  		n++
   126  		i := 0
   127  		for stackIterator.Next() {
   128  			var param uint64
   129  			switch i {
   130  			case 0:
   131  				param = 3
   132  			case 1:
   133  				param = 2
   134  			case 2:
   135  				param = 1
   136  			default:
   137  				t.Errorf("too many frames seen by stack iterator: %d", i)
   138  			}
   139  			if params := stackIterator.Parameters(); len(params) != 1 {
   140  				t.Errorf("wrong number of parameters in call frame %d: want=1 got=%d", i, len(params))
   141  			} else if params[0] != param {
   142  				t.Errorf("wrong parameter in call frame %d: want=%d got=%d", i, param, params[0])
   143  			}
   144  			i++
   145  		}
   146  		if i != 3 {
   147  			t.Errorf("wrong number of call frames: want=3 got=%d", i)
   148  		}
   149  	}
   150  
   151  	factory := experimental.MultiFunctionListenerFactory(
   152  		experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener {
   153  			return experimental.FunctionListenerFunc(f)
   154  		}),
   155  		experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener {
   156  			return experimental.FunctionListenerFunc(f)
   157  		}),
   158  	)
   159  
   160  	function := module.Function(0).Definition()
   161  	listener := factory.NewFunctionListener(function)
   162  	listener.Before(context.Background(), module, function, stack[2].Params, experimental.NewStackIterator(stack...))
   163  
   164  	if n != 2 {
   165  		t.Errorf("wrong number of function calls: want=2 got=%d", n)
   166  	}
   167  }
   168  
   169  func BenchmarkMultiFunctionListener(b *testing.B) {
   170  	module := wazerotest.NewModule(nil,
   171  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   172  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   173  		wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}),
   174  	)
   175  
   176  	stack := []experimental.StackFrame{
   177  		{Function: module.Function(0), Params: []uint64{1}},
   178  		{Function: module.Function(1), Params: []uint64{2}},
   179  		{Function: module.Function(2), Params: []uint64{3}},
   180  	}
   181  
   182  	tests := []struct {
   183  		scenario string
   184  		function func(context.Context, api.Module, api.FunctionDefinition, []uint64, experimental.StackIterator)
   185  	}{
   186  		{
   187  			scenario: "simple function listener",
   188  			function: func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) {
   189  			},
   190  		},
   191  		{
   192  			scenario: "stack iterator",
   193  			function: func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) {
   194  				for stackIterator.Next() {
   195  				}
   196  			},
   197  		},
   198  	}
   199  
   200  	for _, test := range tests {
   201  		b.Run(test.scenario, func(b *testing.B) {
   202  			factory := experimental.MultiFunctionListenerFactory(
   203  				experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener {
   204  					return experimental.FunctionListenerFunc(test.function)
   205  				}),
   206  				experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener {
   207  					return experimental.FunctionListenerFunc(test.function)
   208  				}),
   209  			)
   210  			function := module.Function(0).Definition()
   211  			listener := factory.NewFunctionListener(function)
   212  			experimental.BenchmarkFunctionListener(b.N, module, stack, listener)
   213  		})
   214  	}
   215  }