github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/experimental/checkpoint_test.go (about)

     1  package experimental_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	wazero "github.com/wasilibs/wazerox"
     8  	"github.com/wasilibs/wazerox/api"
     9  	"github.com/wasilibs/wazerox/experimental"
    10  	"github.com/wasilibs/wazerox/internal/testing/require"
    11  )
    12  
    13  func TestSnapshotNestedWasmInvocation(t *testing.T) {
    14  	ctx := context.Background()
    15  
    16  	rt := wazero.NewRuntime(ctx)
    17  	defer rt.Close(ctx)
    18  
    19  	sidechannel := 0
    20  
    21  	_, err := rt.NewHostModuleBuilder("example").
    22  		NewFunctionBuilder().
    23  		WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
    24  			defer func() {
    25  				sidechannel = 10
    26  			}()
    27  			snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
    28  			snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
    29  			idx := len(*snapshots)
    30  			*snapshots = append(*snapshots, snapshot)
    31  			ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
    32  			require.True(t, ok)
    33  
    34  			_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
    35  			require.NoError(t, err)
    36  
    37  			return 2
    38  		}).
    39  		Export("snapshot").
    40  		NewFunctionBuilder().
    41  		WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
    42  			idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
    43  			require.True(t, ok)
    44  			snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
    45  			snapshot := (*snapshots)[idx]
    46  
    47  			snapshot.Restore([]uint64{12})
    48  		}).
    49  		Export("restore").
    50  		Instantiate(ctx)
    51  	require.NoError(t, err)
    52  
    53  	mod, err := rt.Instantiate(ctx, snapshotWasm)
    54  	require.NoError(t, err)
    55  
    56  	var snapshots []experimental.Snapshot
    57  	ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
    58  	ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})
    59  
    60  	snapshotPtr := uint64(0)
    61  	res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
    62  	require.NoError(t, err)
    63  	// return value from restore
    64  	require.Equal(t, uint64(12), res[0])
    65  	// Host function defers within the call stack work fine
    66  	require.Equal(t, 10, sidechannel)
    67  }
    68  
    69  func TestSnapshotMultipleWasmInvocations(t *testing.T) {
    70  	ctx := context.Background()
    71  
    72  	rt := wazero.NewRuntime(ctx)
    73  	defer rt.Close(ctx)
    74  
    75  	_, err := rt.NewHostModuleBuilder("example").
    76  		NewFunctionBuilder().
    77  		WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
    78  			snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
    79  			snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
    80  			idx := len(*snapshots)
    81  			*snapshots = append(*snapshots, snapshot)
    82  			ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
    83  			require.True(t, ok)
    84  
    85  			return 0
    86  		}).
    87  		Export("snapshot").
    88  		NewFunctionBuilder().
    89  		WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
    90  			idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
    91  			require.True(t, ok)
    92  			snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
    93  			snapshot := (*snapshots)[idx]
    94  
    95  			snapshot.Restore([]uint64{12})
    96  		}).
    97  		Export("restore").
    98  		Instantiate(ctx)
    99  	require.NoError(t, err)
   100  
   101  	mod, err := rt.Instantiate(ctx, snapshotWasm)
   102  	require.NoError(t, err)
   103  
   104  	var snapshots []experimental.Snapshot
   105  	ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
   106  	ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})
   107  
   108  	snapshotPtr := uint64(0)
   109  	res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
   110  	require.NoError(t, err)
   111  	// snapshot returned zero
   112  	require.Equal(t, uint64(0), res[0])
   113  
   114  	// Fails, snapshot and restore are called from different wasm invocations. Currently, this
   115  	// results in a panic.
   116  	err = require.CapturePanic(func() {
   117  		_, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr)
   118  	})
   119  	require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
   120  		"exported function invocation than snapshot")
   121  }