github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/modules/wasm/runtime/wasmtime/instance.go (about) 1 package wasmtime 2 3 import ( 4 "context" 5 "encoding/binary" 6 "log/slog" 7 "reflect" 8 "sync" 9 "sync/atomic" 10 11 "github.com/bytecodealliance/wasmtime-go/v17" 12 "github.com/pkg/errors" 13 14 "github.com/machinefi/w3bstream/pkg/modules/wasm/abi/proxy" 15 "github.com/machinefi/w3bstream/pkg/modules/wasm/abi/types" 16 "github.com/machinefi/w3bstream/pkg/modules/wasm/host" 17 ) 18 19 func NewWasmtimeInstance(vm *VM, mod *Module) types.Instance { 20 i := &Instance{ 21 vm: vm, 22 mod: mod, 23 } 24 i.stopCond = sync.NewCond(&i.locker) 25 26 return i 27 } 28 29 type Instance struct { 30 vm *VM 31 mod *Module 32 ins *wasmtime.Instance 33 lnk *wasmtime.Linker 34 35 externs []wasmtime.AsExtern 36 debug *DwarfInfo 37 locker sync.Mutex 38 started atomic.Bool 39 refCount int 40 stopCond *sync.Cond 41 42 mem *wasmtime.Memory 43 fns sync.Map 44 45 data any 46 } 47 48 var _ types.Instance = (*Instance)(nil) 49 50 func (i *Instance) ID() string { 51 return i.vm.id 52 } 53 54 func (i *Instance) register(namespace, fnName string, fn interface{}) error { 55 if namespace == "" || fnName == "" { 56 return ErrInvalidImportFunc 57 } 58 59 if fn == nil || reflect.ValueOf(fn).IsNil() || reflect.TypeOf(fn).Kind() != reflect.Func { 60 return ErrInvalidImportFunc 61 } 62 63 return i.lnk.FuncWrap(namespace, fnName, fn) 64 65 // fnType := reflect.TypeOf(fn) 66 67 // argsNum := fnType.NumIn() 68 // argKinds := make([]*wasmtime.ValType, argsNum) 69 // for i := 0; i < argsNum; i++ { 70 // argKinds[i] = convertFromGoType(fnType.In(i)) 71 // } 72 73 // retsNum := fnType.NumOut() 74 // retKinds := make([]*wasmtime.ValType, retsNum) 75 // for i := 0; i < retsNum; i++ { 76 // retKinds[i] = convertFromGoType(fnType.Out(i)) 77 // } 78 79 // return wasmtime.NewFunc( 80 // i.vm.store, 81 // wasmtime.NewFuncType(argKinds, retKinds), 82 // func(caller *wasmtime.Caller, args []wasmtime.Val) (rets []wasmtime.Val, trap *wasmtime.Trap) { 83 // if len(args) != len(argKinds) { 84 // return nil, wasmtime.NewTrap("wasmtime: unmatched input number of arguments") 85 // } 86 87 // for i := range args { 88 // if args[i].Kind() != argKinds[i].Kind() { 89 // return nil, wasmtime.NewTrap(fmt.Sprintf("wasmtime: unmatched input type of argument: %d", i)) 90 // } 91 // } 92 93 // _args := make([]reflect.Value, len(args)) 94 // for i := range args { 95 // _args[i] = convertToGoTypes(args[i]) 96 // } 97 98 // defer func() { 99 // if r := recover(); r != nil { 100 // trap = wasmtime.NewTrap(fmt.Sprintf("wasmtime: call %s paniced, r: %v stack: %v", fnName, r, string(debug.Stack()))) 101 // rets = nil 102 // } 103 // }() 104 105 // _rets := reflect.ValueOf(fn).Call(_args) 106 // rets = make([]wasmtime.Val, len(_rets)) 107 // for i := range _rets { 108 // rets[i] = convertToWasmtimeVal(_rets[i]) 109 // } 110 // return rets, nil 111 112 // // fn := caller.GetExport(fnName).Func() 113 // // result, err := fn.Call(i.vm.store, _args...) 114 // // if err != nil { 115 // // return nil, wasmtime.NewTrap(err.Error()) 116 // // } 117 // // if result == nil { 118 // // return nil, nil 119 // // } 120 // // if v, ok := result.([]wasmtime.Val); ok { 121 // // return v, nil 122 // // } 123 // // return []wasmtime.Val{convertToWasmtimeVal(result)}, nil 124 // }, 125 // ), nil 126 } 127 128 func (i *Instance) RegisterImports(name string) error { 129 if name != proxy.ABIName { 130 return errors.Wrap(ErrUnknownABIName, name) 131 } 132 133 hostFns, err := host.HostFunctions(i) 134 if err != nil { 135 return err 136 } 137 138 for fnName, fn := range hostFns { 139 if err = i.register(fn.Namespace, fnName, fn.Func); err != nil { 140 return err 141 } 142 } 143 return nil 144 } 145 146 func (i *Instance) Start() error { 147 i.lnk = wasmtime.NewLinker(i.vm.engine) 148 if err := i.lnk.DefineWasi(); err != nil { 149 slog.Error(err.Error()) 150 return err 151 } 152 153 err := i.RegisterImports(proxy.ABIName) 154 if err != nil { 155 slog.Error(err.Error()) 156 return err 157 } 158 159 i.ins, err = i.lnk.Instantiate(i.vm.store, i.mod.mod) 160 if err != nil { 161 slog.Error(err.Error()) 162 return err 163 } 164 165 i.started.Store(true) 166 return nil 167 } 168 169 func (i *Instance) Stop() { 170 i.locker.Lock() 171 defer i.locker.Unlock() 172 for i.refCount > 0 { 173 i.stopCond.Wait() 174 } 175 if i.started.CompareAndSwap(true, false) { 176 // TODO destroy 177 } 178 } 179 180 func (i *Instance) Started() bool { 181 return i.started.Load() 182 } 183 184 func (i *Instance) Malloc(size int32) (int32, error) { 185 if !i.Started() { 186 return 0, ErrInstanceNotStarted 187 } 188 189 // alloc func was implemented in w3bstream-golang-sdk 190 fn, err := i.GetExportsFunc("alloc") 191 if err != nil { 192 fn, err = i.GetExportsFunc("malloc") 193 if err != nil { 194 return 0, err 195 } 196 } 197 198 addr, err := fn.Call(size) 199 if err != nil { 200 i.HandleError(err) 201 return 0, err 202 } 203 return addr.(int32), nil 204 } 205 206 func (i *Instance) GetExportsFunc(name string) (types.Function, error) { 207 if !i.Started() { 208 return nil, ErrInstanceNotStarted 209 } 210 211 if v, ok := i.fns.Load(name); ok { 212 return v.(*wasmtimeNativeFunction), nil 213 } 214 215 export := i.ins.GetExport(i.vm.store, name) 216 if export == nil { 217 return nil, errors.Wrap(ErrInvalidExportFunc, name) 218 } 219 220 f := export.Func() 221 if f == nil { 222 return nil, errors.Wrap(ErrInvalidExportFunc, name) 223 } 224 nf := newWasmtimeNativeFunction(i.vm.store, f) 225 226 i.fns.Store(name, nf) 227 228 return nf, nil 229 } 230 231 func (i *Instance) GetExportsMem(name string) ([]byte, error) { 232 if !i.Started() { 233 return nil, ErrInstanceNotStarted 234 } 235 236 if i.mem == nil { 237 exp := i.ins.GetExport(i.vm.store, name) 238 if exp == nil { 239 return nil, errors.Wrap(ErrInvalidExportMem, name) 240 } 241 m := exp.Memory() 242 if m == nil { 243 return nil, errors.Wrap(ErrInvalidExportMem, name) 244 } 245 i.mem = m 246 } 247 248 return i.mem.UnsafeData(i.vm.store), nil 249 } 250 251 func (i *Instance) GetMemory(addr, size int32) ([]byte, error) { 252 mem, err := i.GetExportsMem("memory") 253 if err != nil { 254 return nil, err 255 } 256 257 if checkIfOverflow(addr, size, mem) { 258 return nil, ErrMemAccessOverflow 259 } 260 261 return mem[addr : addr+size], nil 262 } 263 264 func (i *Instance) PutMemory(addr, size int32, data []byte) error { 265 mem, err := i.GetExportsMem("memory") 266 if err != nil { 267 return err 268 } 269 270 if need := int32(len(data)); need > size { 271 size = need 272 } 273 274 if checkIfOverflow(addr, size, mem) { 275 return ErrMemAccessOverflow 276 } 277 278 copy(mem[addr:], data[:size]) 279 return nil 280 } 281 282 func (i *Instance) GetByte(addr int32) (byte, error) { 283 mem, err := i.GetExportsMem("memory") 284 if err != nil { 285 return 0, err 286 } 287 288 if checkIfOverflow(addr, 0, mem) { 289 return 0, ErrMemAccessOverflow 290 } 291 292 return mem[addr], nil 293 } 294 295 func (i *Instance) PutByte(addr int32, v byte) error { 296 mem, err := i.GetExportsMem("memory") 297 if err != nil { 298 return err 299 } 300 301 if checkIfOverflow(addr, 0, mem) { 302 return ErrMemAccessOverflow 303 } 304 305 mem[addr] = v 306 return nil 307 } 308 309 func (i *Instance) GetUint32(addr int32) (uint32, error) { 310 mem, err := i.GetExportsMem("memory") 311 if err != nil { 312 return 0, err 313 } 314 315 if checkIfOverflow(addr, 4, mem) { 316 return 0, ErrMemAccessOverflow 317 } 318 319 return binary.LittleEndian.Uint32(mem[addr:]), nil 320 } 321 322 func (i *Instance) PutUint32(addr int32, v uint32) error { 323 mem, err := i.GetExportsMem("memory") 324 if err != nil { 325 return err 326 } 327 328 if checkIfOverflow(addr, 4, mem) { 329 return ErrMemAccessOverflow 330 } 331 332 binary.LittleEndian.PutUint32(mem[addr:], v) 333 return nil 334 } 335 336 func (i *Instance) GetModule() types.Module { 337 return i.mod 338 } 339 340 func (i *Instance) GetUserdata() any { 341 return i.data 342 } 343 344 func (i *Instance) SetUserdata(data any) { 345 i.data = data 346 } 347 348 func (i *Instance) Lock(data any) { 349 i.locker.Lock() 350 i.data = data 351 } 352 353 func (i *Instance) Unlock() { 354 i.locker.Unlock() 355 i.data = nil 356 } 357 358 func (i *Instance) Acquire() bool { 359 i.locker.Lock() 360 defer i.locker.Unlock() 361 362 if !i.Started() { 363 return false 364 } 365 366 i.refCount++ 367 return true 368 } 369 370 func (i *Instance) Release() { 371 i.locker.Lock() 372 defer i.locker.Unlock() 373 i.refCount-- 374 375 if i.refCount <= 0 { 376 i.stopCond.Broadcast() 377 } 378 i.vm.store.GC() 379 } 380 381 func (i *Instance) Call(name string, args ...interface{}) (interface{}, uint64, error) { 382 if !i.Started() { 383 return nil, 0, ErrInstanceNotStarted 384 } 385 386 // if v, ok := i.fns.Load(name); ok { 387 // return v.(*wasmtimeNativeFunction), nil 388 // } 389 390 // export := i.ins.GetExport(i.vm.store, name) 391 // if export == nil { 392 // return nil, errors.Wrap(ErrInvalidExportFunc, name) 393 // } 394 395 // f := export.Func() 396 // if f == nil { 397 // return nil, errors.Wrap(ErrInvalidExportFunc, name) 398 // } 399 400 f := i.ins.GetFunc(i.vm.store, name) 401 if f == nil { 402 return nil, 0, errors.Wrap(ErrInvalidExportFunc, name) 403 } 404 405 // before, fuelset := i.vm.store.FuelConsumed() 406 407 ret, err := f.Call(i.vm.store, args...) 408 if err != nil { 409 i.HandleError(err) 410 } 411 // after, _ := i.vm.store.FuelConsumed() 412 // consumed := after - before 413 // slog.Info("check fuel after call "+name, "instance_id", i.vm.id, "before", before, "after", after, "consumed", consumed) 414 // if fuelset { 415 // if err := i.vm.store.AddFuel(consumed); err != nil { 416 // slog.Error("failed to add fuel", "instance_id", i.vm.id, "err", err) 417 // } 418 // } 419 return ret, 0, err 420 } 421 422 func (i *Instance) HandleError(err error) { 423 var trapErr *wasmtime.Trap 424 if !errors.As(err, &trapErr) { 425 return 426 } 427 428 frames := trapErr.Frames() 429 if frames == nil { 430 return 431 } 432 433 for _, f := range frames { 434 args := []any{ 435 "func_index", f.FuncIndex(), 436 "func_offset", f.FuncOffset(), 437 "instance_id", i.vm.id, 438 } 439 pc := uint64(f.ModuleOffset()) 440 if i.debug != nil { 441 if l := i.debug.SeekPC(pc); l != nil { 442 args = append(args, 443 "filename", l.File.Name, 444 "line", l.Line, 445 ) 446 } 447 } 448 slog.Log(context.Background(), slog.LevelError, err.Error(), args...) 449 } 450 }