github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/modules/wasm/runtime/wasmtime/instance.go (about)

     1  package wasmtime
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"log/slog"
     7  	"reflect"
     8  	"sync"
     9  	"sync/atomic"
    10  
    11  	"github.com/bytecodealliance/wasmtime-go/v17"
    12  	"github.com/pkg/errors"
    13  
    14  	"github.com/machinefi/w3bstream/pkg/modules/wasm/abi/proxy"
    15  	"github.com/machinefi/w3bstream/pkg/modules/wasm/abi/types"
    16  	"github.com/machinefi/w3bstream/pkg/modules/wasm/host"
    17  )
    18  
    19  func NewWasmtimeInstance(vm *VM, mod *Module) types.Instance {
    20  	i := &Instance{
    21  		vm:  vm,
    22  		mod: mod,
    23  	}
    24  	i.stopCond = sync.NewCond(&i.locker)
    25  
    26  	return i
    27  }
    28  
// Instance is the wasmtime-backed implementation of types.Instance. It holds
// the linker and the instantiated module for one VM/module pair, and counts
// in-flight users (Acquire/Release) so that Stop can wait for them.
type Instance struct {
	vm  *VM                // owning VM; supplies engine, store and id
	mod *Module            // compiled module instantiated by Start
	ins *wasmtime.Instance // set by Start
	lnk *wasmtime.Linker   // created by Start; host functions are registered here

	externs  []wasmtime.AsExtern // not referenced in this file; presumably reserved — TODO confirm
	debug    *DwarfInfo          // optional; when non-nil HandleError maps frame offsets to file/line
	locker   sync.Mutex          // guards refCount; also held between Lock and Unlock
	started  atomic.Bool         // set by a successful Start, cleared by Stop
	refCount int                 // outstanding Acquire calls not yet matched by Release
	stopCond *sync.Cond          // bound to locker; broadcast by Release when refCount <= 0

	mem *wasmtime.Memory // exported memory, resolved lazily and cached by GetExportsMem
	fns sync.Map         // cache of exported functions: name -> *wasmtimeNativeFunction

	data any // user data; set by SetUserdata/Lock, cleared by Unlock
}

// Compile-time check that *Instance satisfies types.Instance.
var _ types.Instance = (*Instance)(nil)
    49  
    50  func (i *Instance) ID() string {
    51  	return i.vm.id
    52  }
    53  
    54  func (i *Instance) register(namespace, fnName string, fn interface{}) error {
    55  	if namespace == "" || fnName == "" {
    56  		return ErrInvalidImportFunc
    57  	}
    58  
    59  	if fn == nil || reflect.ValueOf(fn).IsNil() || reflect.TypeOf(fn).Kind() != reflect.Func {
    60  		return ErrInvalidImportFunc
    61  	}
    62  
    63  	return i.lnk.FuncWrap(namespace, fnName, fn)
    64  
    65  	// fnType := reflect.TypeOf(fn)
    66  
    67  	// argsNum := fnType.NumIn()
    68  	// argKinds := make([]*wasmtime.ValType, argsNum)
    69  	// for i := 0; i < argsNum; i++ {
    70  	// 	argKinds[i] = convertFromGoType(fnType.In(i))
    71  	// }
    72  
    73  	// retsNum := fnType.NumOut()
    74  	// retKinds := make([]*wasmtime.ValType, retsNum)
    75  	// for i := 0; i < retsNum; i++ {
    76  	// 	retKinds[i] = convertFromGoType(fnType.Out(i))
    77  	// }
    78  
    79  	// return wasmtime.NewFunc(
    80  	// 	i.vm.store,
    81  	// 	wasmtime.NewFuncType(argKinds, retKinds),
    82  	// 	func(caller *wasmtime.Caller, args []wasmtime.Val) (rets []wasmtime.Val, trap *wasmtime.Trap) {
    83  	// 		if len(args) != len(argKinds) {
    84  	// 			return nil, wasmtime.NewTrap("wasmtime: unmatched input number of arguments")
    85  	// 		}
    86  
    87  	// 		for i := range args {
    88  	// 			if args[i].Kind() != argKinds[i].Kind() {
    89  	// 				return nil, wasmtime.NewTrap(fmt.Sprintf("wasmtime: unmatched input type of argument: %d", i))
    90  	// 			}
    91  	// 		}
    92  
    93  	// 		_args := make([]reflect.Value, len(args))
    94  	// 		for i := range args {
    95  	// 			_args[i] = convertToGoTypes(args[i])
    96  	// 		}
    97  
    98  	// 		defer func() {
    99  	// 			if r := recover(); r != nil {
   100  	// 				trap = wasmtime.NewTrap(fmt.Sprintf("wasmtime: call %s paniced, r: %v stack: %v", fnName, r, string(debug.Stack())))
   101  	// 				rets = nil
   102  	// 			}
   103  	// 		}()
   104  
   105  	// 		_rets := reflect.ValueOf(fn).Call(_args)
   106  	// 		rets = make([]wasmtime.Val, len(_rets))
   107  	// 		for i := range _rets {
   108  	// 			rets[i] = convertToWasmtimeVal(_rets[i])
   109  	// 		}
   110  	// 		return rets, nil
   111  
   112  	// 		// fn := caller.GetExport(fnName).Func()
   113  	// 		// result, err := fn.Call(i.vm.store, _args...)
   114  	// 		// if err != nil {
   115  	// 		// 	return nil, wasmtime.NewTrap(err.Error())
   116  	// 		// }
   117  	// 		// if result == nil {
   118  	// 		// 	return nil, nil
   119  	// 		// }
   120  	// 		// if v, ok := result.([]wasmtime.Val); ok {
   121  	// 		// 	return v, nil
   122  	// 		// }
   123  	// 		// return []wasmtime.Val{convertToWasmtimeVal(result)}, nil
   124  	// 	},
   125  	// ), nil
   126  }
   127  
   128  func (i *Instance) RegisterImports(name string) error {
   129  	if name != proxy.ABIName {
   130  		return errors.Wrap(ErrUnknownABIName, name)
   131  	}
   132  
   133  	hostFns, err := host.HostFunctions(i)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	for fnName, fn := range hostFns {
   139  		if err = i.register(fn.Namespace, fnName, fn.Func); err != nil {
   140  			return err
   141  		}
   142  	}
   143  	return nil
   144  }
   145  
   146  func (i *Instance) Start() error {
   147  	i.lnk = wasmtime.NewLinker(i.vm.engine)
   148  	if err := i.lnk.DefineWasi(); err != nil {
   149  		slog.Error(err.Error())
   150  		return err
   151  	}
   152  
   153  	err := i.RegisterImports(proxy.ABIName)
   154  	if err != nil {
   155  		slog.Error(err.Error())
   156  		return err
   157  	}
   158  
   159  	i.ins, err = i.lnk.Instantiate(i.vm.store, i.mod.mod)
   160  	if err != nil {
   161  		slog.Error(err.Error())
   162  		return err
   163  	}
   164  
   165  	i.started.Store(true)
   166  	return nil
   167  }
   168  
   169  func (i *Instance) Stop() {
   170  	i.locker.Lock()
   171  	defer i.locker.Unlock()
   172  	for i.refCount > 0 {
   173  		i.stopCond.Wait()
   174  	}
   175  	if i.started.CompareAndSwap(true, false) {
   176  		// TODO destroy
   177  	}
   178  }
   179  
   180  func (i *Instance) Started() bool {
   181  	return i.started.Load()
   182  }
   183  
   184  func (i *Instance) Malloc(size int32) (int32, error) {
   185  	if !i.Started() {
   186  		return 0, ErrInstanceNotStarted
   187  	}
   188  
   189  	// alloc func was implemented in w3bstream-golang-sdk
   190  	fn, err := i.GetExportsFunc("alloc")
   191  	if err != nil {
   192  		fn, err = i.GetExportsFunc("malloc")
   193  		if err != nil {
   194  			return 0, err
   195  		}
   196  	}
   197  
   198  	addr, err := fn.Call(size)
   199  	if err != nil {
   200  		i.HandleError(err)
   201  		return 0, err
   202  	}
   203  	return addr.(int32), nil
   204  }
   205  
   206  func (i *Instance) GetExportsFunc(name string) (types.Function, error) {
   207  	if !i.Started() {
   208  		return nil, ErrInstanceNotStarted
   209  	}
   210  
   211  	if v, ok := i.fns.Load(name); ok {
   212  		return v.(*wasmtimeNativeFunction), nil
   213  	}
   214  
   215  	export := i.ins.GetExport(i.vm.store, name)
   216  	if export == nil {
   217  		return nil, errors.Wrap(ErrInvalidExportFunc, name)
   218  	}
   219  
   220  	f := export.Func()
   221  	if f == nil {
   222  		return nil, errors.Wrap(ErrInvalidExportFunc, name)
   223  	}
   224  	nf := newWasmtimeNativeFunction(i.vm.store, f)
   225  
   226  	i.fns.Store(name, nf)
   227  
   228  	return nf, nil
   229  }
   230  
   231  func (i *Instance) GetExportsMem(name string) ([]byte, error) {
   232  	if !i.Started() {
   233  		return nil, ErrInstanceNotStarted
   234  	}
   235  
   236  	if i.mem == nil {
   237  		exp := i.ins.GetExport(i.vm.store, name)
   238  		if exp == nil {
   239  			return nil, errors.Wrap(ErrInvalidExportMem, name)
   240  		}
   241  		m := exp.Memory()
   242  		if m == nil {
   243  			return nil, errors.Wrap(ErrInvalidExportMem, name)
   244  		}
   245  		i.mem = m
   246  	}
   247  
   248  	return i.mem.UnsafeData(i.vm.store), nil
   249  }
   250  
   251  func (i *Instance) GetMemory(addr, size int32) ([]byte, error) {
   252  	mem, err := i.GetExportsMem("memory")
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	if checkIfOverflow(addr, size, mem) {
   258  		return nil, ErrMemAccessOverflow
   259  	}
   260  
   261  	return mem[addr : addr+size], nil
   262  }
   263  
   264  func (i *Instance) PutMemory(addr, size int32, data []byte) error {
   265  	mem, err := i.GetExportsMem("memory")
   266  	if err != nil {
   267  		return err
   268  	}
   269  
   270  	if need := int32(len(data)); need > size {
   271  		size = need
   272  	}
   273  
   274  	if checkIfOverflow(addr, size, mem) {
   275  		return ErrMemAccessOverflow
   276  	}
   277  
   278  	copy(mem[addr:], data[:size])
   279  	return nil
   280  }
   281  
   282  func (i *Instance) GetByte(addr int32) (byte, error) {
   283  	mem, err := i.GetExportsMem("memory")
   284  	if err != nil {
   285  		return 0, err
   286  	}
   287  
   288  	if checkIfOverflow(addr, 0, mem) {
   289  		return 0, ErrMemAccessOverflow
   290  	}
   291  
   292  	return mem[addr], nil
   293  }
   294  
   295  func (i *Instance) PutByte(addr int32, v byte) error {
   296  	mem, err := i.GetExportsMem("memory")
   297  	if err != nil {
   298  		return err
   299  	}
   300  
   301  	if checkIfOverflow(addr, 0, mem) {
   302  		return ErrMemAccessOverflow
   303  	}
   304  
   305  	mem[addr] = v
   306  	return nil
   307  }
   308  
   309  func (i *Instance) GetUint32(addr int32) (uint32, error) {
   310  	mem, err := i.GetExportsMem("memory")
   311  	if err != nil {
   312  		return 0, err
   313  	}
   314  
   315  	if checkIfOverflow(addr, 4, mem) {
   316  		return 0, ErrMemAccessOverflow
   317  	}
   318  
   319  	return binary.LittleEndian.Uint32(mem[addr:]), nil
   320  }
   321  
   322  func (i *Instance) PutUint32(addr int32, v uint32) error {
   323  	mem, err := i.GetExportsMem("memory")
   324  	if err != nil {
   325  		return err
   326  	}
   327  
   328  	if checkIfOverflow(addr, 4, mem) {
   329  		return ErrMemAccessOverflow
   330  	}
   331  
   332  	binary.LittleEndian.PutUint32(mem[addr:], v)
   333  	return nil
   334  }
   335  
   336  func (i *Instance) GetModule() types.Module {
   337  	return i.mod
   338  }
   339  
   340  func (i *Instance) GetUserdata() any {
   341  	return i.data
   342  }
   343  
   344  func (i *Instance) SetUserdata(data any) {
   345  	i.data = data
   346  }
   347  
   348  func (i *Instance) Lock(data any) {
   349  	i.locker.Lock()
   350  	i.data = data
   351  }
   352  
   353  func (i *Instance) Unlock() {
   354  	i.locker.Unlock()
   355  	i.data = nil
   356  }
   357  
   358  func (i *Instance) Acquire() bool {
   359  	i.locker.Lock()
   360  	defer i.locker.Unlock()
   361  
   362  	if !i.Started() {
   363  		return false
   364  	}
   365  
   366  	i.refCount++
   367  	return true
   368  }
   369  
   370  func (i *Instance) Release() {
   371  	i.locker.Lock()
   372  	defer i.locker.Unlock()
   373  	i.refCount--
   374  
   375  	if i.refCount <= 0 {
   376  		i.stopCond.Broadcast()
   377  	}
   378  	i.vm.store.GC()
   379  }
   380  
   381  func (i *Instance) Call(name string, args ...interface{}) (interface{}, uint64, error) {
   382  	if !i.Started() {
   383  		return nil, 0, ErrInstanceNotStarted
   384  	}
   385  
   386  	// if v, ok := i.fns.Load(name); ok {
   387  	// 	return v.(*wasmtimeNativeFunction), nil
   388  	// }
   389  
   390  	// export := i.ins.GetExport(i.vm.store, name)
   391  	// if export == nil {
   392  	// 	return nil, errors.Wrap(ErrInvalidExportFunc, name)
   393  	// }
   394  
   395  	// f := export.Func()
   396  	// if f == nil {
   397  	// 	return nil, errors.Wrap(ErrInvalidExportFunc, name)
   398  	// }
   399  
   400  	f := i.ins.GetFunc(i.vm.store, name)
   401  	if f == nil {
   402  		return nil, 0, errors.Wrap(ErrInvalidExportFunc, name)
   403  	}
   404  
   405  	// before, fuelset := i.vm.store.FuelConsumed()
   406  
   407  	ret, err := f.Call(i.vm.store, args...)
   408  	if err != nil {
   409  		i.HandleError(err)
   410  	}
   411  	// after, _ := i.vm.store.FuelConsumed()
   412  	// consumed := after - before
   413  	// slog.Info("check fuel after call "+name, "instance_id", i.vm.id, "before", before, "after", after, "consumed", consumed)
   414  	// if fuelset {
   415  	// 	if err := i.vm.store.AddFuel(consumed); err != nil {
   416  	// 		slog.Error("failed to add fuel", "instance_id", i.vm.id, "err", err)
   417  	// 	}
   418  	// }
   419  	return ret, 0, err
   420  }
   421  
   422  func (i *Instance) HandleError(err error) {
   423  	var trapErr *wasmtime.Trap
   424  	if !errors.As(err, &trapErr) {
   425  		return
   426  	}
   427  
   428  	frames := trapErr.Frames()
   429  	if frames == nil {
   430  		return
   431  	}
   432  
   433  	for _, f := range frames {
   434  		args := []any{
   435  			"func_index", f.FuncIndex(),
   436  			"func_offset", f.FuncOffset(),
   437  			"instance_id", i.vm.id,
   438  		}
   439  		pc := uint64(f.ModuleOffset())
   440  		if i.debug != nil {
   441  			if l := i.debug.SeekPC(pc); l != nil {
   442  				args = append(args,
   443  					"filename", l.File.Name,
   444  					"line", l.Line,
   445  				)
   446  			}
   447  		}
   448  		slog.Log(context.Background(), slog.LevelError, err.Error(), args...)
   449  	}
   450  }