github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/experimental/features_example_test.go (about) 1 package experimental_test 2 3 import ( 4 "context" 5 _ "embed" 6 "fmt" 7 "log" 8 "runtime" 9 "sync/atomic" 10 11 wazero "github.com/wasilibs/wazerox" 12 "github.com/wasilibs/wazerox/api" 13 "github.com/wasilibs/wazerox/experimental" 14 "github.com/wasilibs/wazerox/imports/wasi_snapshot_preview1" 15 ) 16 17 // pthreadWasm was generated by the following: 18 // 19 // docker run -it --rm -v `pwd`/testdata:/workspace ghcr.io/webassembly/wasi-sdk:wasi-sdk-20 sh -c '$CC -o /workspace/pthread.wasm /workspace/pthread.c --target=wasm32-wasi-threads --sysroot=/wasi-sysroot -pthread -mexec-model=reactor -Wl,--export=run -Wl,--export=get' 20 // 21 // TODO: Use zig cc instead of wasi-sdk to compile when it supports wasm32-wasi-threads 22 // https://github.com/ziglang/zig/issues/15484 23 // 24 //go:embed testdata/pthread.wasm 25 var pthreadWasm []byte 26 27 //go:embed testdata/memory.wasm 28 var memoryWasm []byte 29 30 // This shows how to use a WebAssembly module compiled with the threads feature. 31 func ExampleCoreFeaturesThreads() { 32 // Use a default context 33 ctx := context.Background() 34 35 // Threads support must be enabled explicitly in addition to standard V2 features. 36 cfg := wazero.NewRuntimeConfig().WithCoreFeatures(api.CoreFeaturesV2 | experimental.CoreFeaturesThreads) 37 38 r := wazero.NewRuntimeWithConfig(ctx, cfg) 39 defer r.Close(ctx) 40 41 wasmCompiled, err := r.CompileModule(ctx, pthreadWasm) 42 if err != nil { 43 log.Panicln(err) 44 } 45 46 // Because we are using wasi-sdk to compile the guest, we must initialize WASI. 47 wasi_snapshot_preview1.MustInstantiate(ctx, r) 48 49 if _, err := r.InstantiateWithConfig(ctx, memoryWasm, wazero.NewModuleConfig().WithName("env")); err != nil { 50 log.Panicln(err) 51 } 52 53 mod, err := r.InstantiateModule(ctx, wasmCompiled, wazero.NewModuleConfig().WithStartFunctions("_initialize")) 54 if err != nil { 55 log.Panicln(err) 56 } 57 58 // Channel to synchronize start of goroutines before running. 59 startCh := make(chan struct{}) 60 // Channel to synchronize end of goroutines. 61 endCh := make(chan struct{}) 62 63 // We start up 8 goroutines and run for 6000 iterations each. The count should reach 64 // 48000, at the end, but it would not if threads weren't working! 65 for i := 0; i < 8; i++ { 66 go func() { 67 defer func() { endCh <- struct{}{} }() 68 <-startCh 69 70 // We must instantiate a child per simultaneous thread. This should normally be pooled 71 // among arbitrary goroutine invocations. 72 child := createChildModule(r, mod, wasmCompiled) 73 fn := child.mod.ExportedFunction("run") 74 for i := 0; i < 6000; i++ { 75 _, err := fn.Call(ctx) 76 if err != nil { 77 log.Panicln(err) 78 } 79 } 80 runtime.KeepAlive(child) 81 }() 82 } 83 for i := 0; i < 8; i++ { 84 startCh <- struct{}{} 85 } 86 for i := 0; i < 8; i++ { 87 <-endCh 88 } 89 90 res, err := mod.ExportedFunction("get").Call(ctx) 91 if err != nil { 92 log.Panicln(err) 93 } 94 fmt.Println(res[0]) 95 // Output: 48000 96 } 97 98 type childModule struct { 99 mod api.Module 100 tlsBasePtr uint32 101 } 102 103 var prevTID uint32 104 105 func createChildModule(rt wazero.Runtime, root api.Module, wasmCompiled wazero.CompiledModule) *childModule { 106 ctx := context.Background() 107 108 // Not executing function so is at end of stack 109 stackPointer := root.ExportedGlobal("__stack_pointer").Get() 110 tlsBase := root.ExportedGlobal("__tls_base").Get() 111 112 // Thread-local-storage for the main thread is from __tls_base to __stack_pointer 113 size := stackPointer - tlsBase 114 115 malloc := root.ExportedFunction("malloc") 116 117 // Allocate memory for the child thread stack 118 res, err := malloc.Call(ctx, size) 119 if err != nil { 120 panic(err) 121 } 122 ptr := uint32(res[0]) 123 124 child, err := rt.InstantiateModule(ctx, wasmCompiled, wazero.NewModuleConfig(). 125 // Don't need to execute start functions again in child, it crashes anyways because 126 // LLVM only allows calling them once. 127 WithStartFunctions()) 128 if err != nil { 129 panic(err) 130 } 131 initTLS := child.ExportedFunction("__wasm_init_tls") 132 if _, err := initTLS.Call(ctx, uint64(ptr)); err != nil { 133 panic(err) 134 } 135 136 tid := atomic.AddUint32(&prevTID, 1) 137 root.Memory().WriteUint32Le(ptr, ptr) 138 root.Memory().WriteUint32Le(ptr+20, tid) 139 child.ExportedGlobal("__stack_pointer").(api.MutableGlobal).Set(uint64(ptr) + size) 140 141 ret := &childModule{ 142 mod: child, 143 tlsBasePtr: ptr, 144 } 145 runtime.SetFinalizer(ret, func(obj interface{}) { 146 cm := obj.(*childModule) 147 free := cm.mod.ExportedFunction("free") 148 // Ignore errors since runtime may have been closed before this is called. 149 _, _ = free.Call(ctx, uint64(cm.tlsBasePtr)) 150 _ = cm.mod.Close(context.Background()) 151 }) 152 return ret 153 }