github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/experimental/listener_test.go (about) 1 package experimental_test 2 3 import ( 4 "context" 5 _ "embed" 6 "testing" 7 8 "github.com/bananabytelabs/wazero" 9 "github.com/bananabytelabs/wazero/api" 10 "github.com/bananabytelabs/wazero/experimental" 11 "github.com/bananabytelabs/wazero/experimental/wazerotest" 12 "github.com/bananabytelabs/wazero/internal/testing/binaryencoding" 13 "github.com/bananabytelabs/wazero/internal/testing/require" 14 "github.com/bananabytelabs/wazero/internal/wasm" 15 ) 16 17 // compile-time check to ensure recorder implements FunctionListenerFactory 18 var _ experimental.FunctionListenerFactory = &recorder{} 19 20 type recorder struct { 21 m map[string]struct{} 22 beforeNames []string 23 afterNames []string 24 abortNames []string 25 } 26 27 func (r *recorder) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64, _ experimental.StackIterator) { 28 r.beforeNames = append(r.beforeNames, def.DebugName()) 29 } 30 31 func (r *recorder) After(_ context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64) { 32 r.afterNames = append(r.afterNames, def.DebugName()) 33 } 34 35 func (r *recorder) Abort(_ context.Context, _ api.Module, def api.FunctionDefinition, _ error) { 36 r.abortNames = append(r.abortNames, def.DebugName()) 37 } 38 39 func (r *recorder) NewFunctionListener(definition api.FunctionDefinition) experimental.FunctionListener { 40 r.m[definition.Name()] = struct{}{} 41 return r 42 } 43 44 func TestFunctionListenerFactory(t *testing.T) { 45 // Set context to one that has an experimental listener 46 factory := &recorder{m: map[string]struct{}{}} 47 ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, factory) 48 49 // Define a module with two functions 50 bin := binaryencoding.EncodeModule(&wasm.Module{ 51 TypeSection: []wasm.FunctionType{{}}, 52 ImportSection: []wasm.Import{{Module: "host"}}, 53 FunctionSection: []wasm.Index{0, 0}, 54 CodeSection: []wasm.Code{ 55 // fn1 56 {Body: []byte{ 57 // call fn2 twice 58 wasm.OpcodeCall, 2, 59 wasm.OpcodeCall, 2, 60 wasm.OpcodeEnd, 61 }}, 62 // fn2 63 {Body: []byte{wasm.OpcodeEnd}}, 64 }, 65 ExportSection: []wasm.Export{{Name: "fn1", Type: wasm.ExternTypeFunc, Index: 1}}, 66 NameSection: &wasm.NameSection{ 67 ModuleName: "test", 68 FunctionNames: wasm.NameMap{ 69 {Index: 0, Name: "import"}, // should skip for building listeners. 70 {Index: 1, Name: "fn1"}, 71 {Index: 2, Name: "fn2"}, 72 }, 73 }, 74 }) 75 76 r := wazero.NewRuntime(ctx) 77 defer r.Close(ctx) // This closes everything this Runtime created. 78 79 _, err := r.NewHostModuleBuilder("host").NewFunctionBuilder().WithFunc(func() {}).Export("").Instantiate(ctx) 80 require.NoError(t, err) 81 82 // Ensure the imported function was converted to a listener. 83 require.Equal(t, map[string]struct{}{"": {}}, factory.m) 84 85 compiled, err := r.CompileModule(ctx, bin) 86 require.NoError(t, err) 87 88 // Ensure each function was converted to a listener eagerly 89 require.Equal(t, map[string]struct{}{ 90 "": {}, 91 "fn1": {}, 92 "fn2": {}, 93 }, factory.m) 94 95 // Ensures that FunctionListener is a compile-time option, so passing 96 // context.Background here is ok to use listeners at runtime. 97 m, err := r.InstantiateModule(context.Background(), compiled, wazero.NewModuleConfig()) 98 require.NoError(t, err) 99 100 fn1 := m.ExportedFunction("fn1") 101 require.NotNil(t, fn1) 102 103 _, err = fn1.Call(context.Background()) 104 require.NoError(t, err) 105 106 require.Equal(t, []string{"test.fn1", "test.fn2", "test.fn2"}, factory.beforeNames) 107 require.Equal(t, []string{"test.fn2", "test.fn2", "test.fn1"}, factory.afterNames) // after is in the reverse order. 108 } 109 110 func TestMultiFunctionListenerFactory(t *testing.T) { 111 module := wazerotest.NewModule(nil, 112 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 113 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 114 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 115 ) 116 117 stack := []experimental.StackFrame{ 118 {Function: module.Function(0), Params: []uint64{1}}, 119 {Function: module.Function(1), Params: []uint64{2}}, 120 {Function: module.Function(2), Params: []uint64{3}}, 121 } 122 123 n := 0 124 f := func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) { 125 n++ 126 i := 0 127 for stackIterator.Next() { 128 i++ 129 } 130 if i != 3 { 131 t.Errorf("wrong number of call frames: want=3 got=%d", i) 132 } 133 } 134 135 factory := experimental.MultiFunctionListenerFactory( 136 experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener { 137 return experimental.FunctionListenerFunc(f) 138 }), 139 experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener { 140 return experimental.FunctionListenerFunc(f) 141 }), 142 ) 143 144 function := module.Function(0).Definition() 145 listener := factory.NewFunctionListener(function) 146 listener.Before(context.Background(), module, function, stack[2].Params, experimental.NewStackIterator(stack...)) 147 148 if n != 2 { 149 t.Errorf("wrong number of function calls: want=2 got=%d", n) 150 } 151 } 152 153 func BenchmarkMultiFunctionListener(b *testing.B) { 154 module := wazerotest.NewModule(nil, 155 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 156 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 157 wazerotest.NewFunction(func(ctx context.Context, mod api.Module, value int32) {}), 158 ) 159 160 stack := []experimental.StackFrame{ 161 {Function: module.Function(0), Params: []uint64{1}}, 162 {Function: module.Function(1), Params: []uint64{2}}, 163 {Function: module.Function(2), Params: []uint64{3}}, 164 } 165 166 tests := []struct { 167 scenario string 168 function func(context.Context, api.Module, api.FunctionDefinition, []uint64, experimental.StackIterator) 169 }{ 170 { 171 scenario: "simple function listener", 172 function: func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) { 173 }, 174 }, 175 { 176 scenario: "stack iterator", 177 function: func(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stackIterator experimental.StackIterator) { 178 for stackIterator.Next() { 179 } 180 }, 181 }, 182 } 183 184 for _, test := range tests { 185 b.Run(test.scenario, func(b *testing.B) { 186 factory := experimental.MultiFunctionListenerFactory( 187 experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener { 188 return experimental.FunctionListenerFunc(test.function) 189 }), 190 experimental.FunctionListenerFactoryFunc(func(def api.FunctionDefinition) experimental.FunctionListener { 191 return experimental.FunctionListenerFunc(test.function) 192 }), 193 ) 194 function := module.Function(0).Definition() 195 listener := factory.NewFunctionListener(function) 196 experimental.BenchmarkFunctionListener(b.N, module, stack, listener) 197 }) 198 } 199 }