github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/experimental/features_example_test.go (about)

     1  package experimental_test
     2  
import (
	"context"
	_ "embed"
	"fmt"
	"log"
	"runtime"
	"sync"
	"sync/atomic"

	wazero "github.com/wasilibs/wazerox"
	"github.com/wasilibs/wazerox/api"
	"github.com/wasilibs/wazerox/experimental"
	"github.com/wasilibs/wazerox/imports/wasi_snapshot_preview1"
)
    16  
// pthreadWasm is a reactor-model guest (-mexec-model=reactor) built for the
// wasm32-wasi-threads target. The -Wl,--export flags below export "run" and
// "get". ExampleCoreFeaturesThreads additionally relies on the module
// exporting malloc, free, __wasm_init_tls, _initialize and the
// __stack_pointer / __tls_base globals.
//
// pthreadWasm was generated by the following:
//
//	docker run -it --rm -v `pwd`/testdata:/workspace ghcr.io/webassembly/wasi-sdk:wasi-sdk-20 sh -c '$CC -o /workspace/pthread.wasm /workspace/pthread.c --target=wasm32-wasi-threads --sysroot=/wasi-sysroot -pthread -mexec-model=reactor -Wl,--export=run -Wl,--export=get'
//
// TODO: Use zig cc instead of wasi-sdk to compile when it supports wasm32-wasi-threads
// https://github.com/ziglang/zig/issues/15484
//
//go:embed testdata/pthread.wasm
var pthreadWasm []byte

// memoryWasm is instantiated under the module name "env" before pthreadWasm
// is instantiated (see ExampleCoreFeaturesThreads).
// NOTE(review): presumably it exports the shared memory that pthreadWasm
// imports from "env" — confirm against testdata.
//
//go:embed testdata/memory.wasm
var memoryWasm []byte
    29  
    30  // This shows how to use a WebAssembly module compiled with the threads feature.
    31  func ExampleCoreFeaturesThreads() {
    32  	// Use a default context
    33  	ctx := context.Background()
    34  
    35  	// Threads support must be enabled explicitly in addition to standard V2 features.
    36  	cfg := wazero.NewRuntimeConfig().WithCoreFeatures(api.CoreFeaturesV2 | experimental.CoreFeaturesThreads)
    37  
    38  	r := wazero.NewRuntimeWithConfig(ctx, cfg)
    39  	defer r.Close(ctx)
    40  
    41  	wasmCompiled, err := r.CompileModule(ctx, pthreadWasm)
    42  	if err != nil {
    43  		log.Panicln(err)
    44  	}
    45  
    46  	// Because we are using wasi-sdk to compile the guest, we must initialize WASI.
    47  	wasi_snapshot_preview1.MustInstantiate(ctx, r)
    48  
    49  	if _, err := r.InstantiateWithConfig(ctx, memoryWasm, wazero.NewModuleConfig().WithName("env")); err != nil {
    50  		log.Panicln(err)
    51  	}
    52  
    53  	mod, err := r.InstantiateModule(ctx, wasmCompiled, wazero.NewModuleConfig().WithStartFunctions("_initialize"))
    54  	if err != nil {
    55  		log.Panicln(err)
    56  	}
    57  
    58  	// Channel to synchronize start of goroutines before running.
    59  	startCh := make(chan struct{})
    60  	// Channel to synchronize end of goroutines.
    61  	endCh := make(chan struct{})
    62  
    63  	// We start up 8 goroutines and run for 6000 iterations each. The count should reach
    64  	// 48000, at the end, but it would not if threads weren't working!
    65  	for i := 0; i < 8; i++ {
    66  		go func() {
    67  			defer func() { endCh <- struct{}{} }()
    68  			<-startCh
    69  
    70  			// We must instantiate a child per simultaneous thread. This should normally be pooled
    71  			// among arbitrary goroutine invocations.
    72  			child := createChildModule(r, mod, wasmCompiled)
    73  			fn := child.mod.ExportedFunction("run")
    74  			for i := 0; i < 6000; i++ {
    75  				_, err := fn.Call(ctx)
    76  				if err != nil {
    77  					log.Panicln(err)
    78  				}
    79  			}
    80  			runtime.KeepAlive(child)
    81  		}()
    82  	}
    83  	for i := 0; i < 8; i++ {
    84  		startCh <- struct{}{}
    85  	}
    86  	for i := 0; i < 8; i++ {
    87  		<-endCh
    88  	}
    89  
    90  	res, err := mod.ExportedFunction("get").Call(ctx)
    91  	if err != nil {
    92  		log.Panicln(err)
    93  	}
    94  	fmt.Println(res[0])
    95  	// Output: 48000
    96  }
    97  
    98  type childModule struct {
    99  	mod        api.Module
   100  	tlsBasePtr uint32
   101  }
   102  
   103  var prevTID uint32
   104  
   105  func createChildModule(rt wazero.Runtime, root api.Module, wasmCompiled wazero.CompiledModule) *childModule {
   106  	ctx := context.Background()
   107  
   108  	// Not executing function so is at end of stack
   109  	stackPointer := root.ExportedGlobal("__stack_pointer").Get()
   110  	tlsBase := root.ExportedGlobal("__tls_base").Get()
   111  
   112  	// Thread-local-storage for the main thread is from __tls_base to __stack_pointer
   113  	size := stackPointer - tlsBase
   114  
   115  	malloc := root.ExportedFunction("malloc")
   116  
   117  	// Allocate memory for the child thread stack
   118  	res, err := malloc.Call(ctx, size)
   119  	if err != nil {
   120  		panic(err)
   121  	}
   122  	ptr := uint32(res[0])
   123  
   124  	child, err := rt.InstantiateModule(ctx, wasmCompiled, wazero.NewModuleConfig().
   125  		// Don't need to execute start functions again in child, it crashes anyways because
   126  		// LLVM only allows calling them once.
   127  		WithStartFunctions())
   128  	if err != nil {
   129  		panic(err)
   130  	}
   131  	initTLS := child.ExportedFunction("__wasm_init_tls")
   132  	if _, err := initTLS.Call(ctx, uint64(ptr)); err != nil {
   133  		panic(err)
   134  	}
   135  
   136  	tid := atomic.AddUint32(&prevTID, 1)
   137  	root.Memory().WriteUint32Le(ptr, ptr)
   138  	root.Memory().WriteUint32Le(ptr+20, tid)
   139  	child.ExportedGlobal("__stack_pointer").(api.MutableGlobal).Set(uint64(ptr) + size)
   140  
   141  	ret := &childModule{
   142  		mod:        child,
   143  		tlsBasePtr: ptr,
   144  	}
   145  	runtime.SetFinalizer(ret, func(obj interface{}) {
   146  		cm := obj.(*childModule)
   147  		free := cm.mod.ExportedFunction("free")
   148  		// Ignore errors since runtime may have been closed before this is called.
   149  		_, _ = free.Call(ctx, uint64(cm.tlsBasePtr))
   150  		_ = cm.mod.Close(context.Background())
   151  	})
   152  	return ret
   153  }