github.com/tetratelabs/wazero@v1.2.1/experimental/wazerotest/wazerotest.go (about)

     1  package wazerotest
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"reflect"
    10  	"strconv"
    11  	"sync"
    12  	"sync/atomic"
    13  
    14  	"github.com/tetratelabs/wazero/api"
    15  	"github.com/tetratelabs/wazero/internal/internalapi"
    16  	"github.com/tetratelabs/wazero/sys"
    17  )
    18  
    19  const (
    20  	exitStatusMarker = 1 << 63
    21  )
    22  
    23  // Module is an implementation of the api.Module interface, it represents a
    24  // WebAssembly module.
    25  type Module struct {
    26  	internalapi.WazeroOnlyType
    27  	exitStatus uint64
    28  
    29  	// The module name that will be returned by calling the Name method.
    30  	ModuleName string
    31  
    32  	// The list of functions of the module. Functions with a non-empty export
    33  	// names will be exported by the module.
    34  	Functions []*Function
    35  
    36  	// The list of globals of the module. Global
    37  	Globals []*Global
    38  
    39  	// The program memory. If non-nil, the memory is automatically exported as
    40  	// "memory".
    41  	ExportMemory *Memory
    42  
    43  	once                        sync.Once
    44  	exportedFunctions           map[string]api.Function
    45  	exportedFunctionDefinitions map[string]api.FunctionDefinition
    46  	exportedGlobals             map[string]api.Global
    47  	exportedMemoryDefinitions   map[string]api.MemoryDefinition
    48  }
    49  
    50  // NewModule constructs a Module object with the given memory and function list.
    51  func NewModule(memory *Memory, functions ...*Function) *Module {
    52  	return &Module{Functions: functions, ExportMemory: memory}
    53  }
    54  
    55  func (m *Module) String() string {
    56  	return "module[" + m.ModuleName + "]"
    57  }
    58  
    59  func (m *Module) Name() string {
    60  	return m.ModuleName
    61  }
    62  
    63  func (m *Module) Memory() api.Memory {
    64  	if m.ExportMemory != nil {
    65  		m.once.Do(m.initialize)
    66  		return m.ExportMemory
    67  	}
    68  	return nil
    69  }
    70  
    71  func (m *Module) ExportedFunction(name string) api.Function {
    72  	m.once.Do(m.initialize)
    73  	return m.exportedFunctions[name]
    74  }
    75  
    76  func (m *Module) ExportedFunctionDefinitions() map[string]api.FunctionDefinition {
    77  	m.once.Do(m.initialize)
    78  	return m.exportedFunctionDefinitions
    79  }
    80  
    81  func (m *Module) ExportedMemory(name string) api.Memory {
    82  	if m.ExportMemory != nil && name == "memory" {
    83  		m.once.Do(m.initialize)
    84  		return m.ExportMemory
    85  	}
    86  	return nil
    87  }
    88  
    89  func (m *Module) ExportedMemoryDefinitions() map[string]api.MemoryDefinition {
    90  	m.once.Do(m.initialize)
    91  	return m.exportedMemoryDefinitions
    92  }
    93  
    94  func (m *Module) ExportedGlobal(name string) api.Global {
    95  	m.once.Do(m.initialize)
    96  	return m.exportedGlobals[name]
    97  }
    98  
    99  func (m *Module) NumFunction() int {
   100  	return len(m.Functions)
   101  }
   102  
   103  func (m *Module) Function(i int) api.Function {
   104  	m.once.Do(m.initialize)
   105  	return m.Functions[i]
   106  }
   107  
   108  func (m *Module) NumGlobal() int {
   109  	return len(m.Globals)
   110  }
   111  
   112  func (m *Module) Global(i int) api.Global {
   113  	m.once.Do(m.initialize)
   114  	return m.Globals[i]
   115  }
   116  
   117  func (m *Module) Close(ctx context.Context) error {
   118  	return m.CloseWithExitCode(ctx, 0)
   119  }
   120  
   121  func (m *Module) CloseWithExitCode(ctx context.Context, exitCode uint32) error {
   122  	atomic.CompareAndSwapUint64(&m.exitStatus, 0, exitStatusMarker|uint64(exitCode))
   123  	return nil
   124  }
   125  
   126  func (m *Module) ExitStatus() (exitCode uint32, exited bool) {
   127  	exitStatus := atomic.LoadUint64(&m.exitStatus)
   128  	return uint32(exitStatus), exitStatus != 0
   129  }
   130  
   131  func (m *Module) initialize() {
   132  	m.exportedFunctions = make(map[string]api.Function)
   133  	m.exportedFunctionDefinitions = make(map[string]api.FunctionDefinition)
   134  	m.exportedGlobals = make(map[string]api.Global)
   135  	m.exportedMemoryDefinitions = make(map[string]api.MemoryDefinition)
   136  
   137  	for index, function := range m.Functions {
   138  		for _, exportName := range function.ExportNames {
   139  			m.exportedFunctions[exportName] = function
   140  			m.exportedFunctionDefinitions[exportName] = function.Definition()
   141  		}
   142  		function.module = m
   143  		function.index = index
   144  	}
   145  
   146  	for _, global := range m.Globals {
   147  		for _, exportName := range global.ExportNames {
   148  			m.exportedGlobals[exportName] = global
   149  		}
   150  	}
   151  
   152  	if m.ExportMemory != nil {
   153  		m.ExportMemory.module = m
   154  		m.exportedMemoryDefinitions["memory"] = m.ExportMemory.Definition()
   155  	}
   156  }
   157  
   158  // Global is an implementation of the api.Global interface, it represents a
   159  // global in a WebAssembly module.
   160  type Global struct {
   161  	internalapi.WazeroOnlyType
   162  
   163  	// Type of the global value, used to interpret bits of the Value field.
   164  	ValueType api.ValueType
   165  
   166  	// Value of the global packed in a 64 bits field.
   167  	Value uint64
   168  
   169  	// List of names that the globla is exported as.
   170  	ExportNames []string
   171  }
   172  
   173  func (g *Global) String() string {
   174  	switch g.ValueType {
   175  	case api.ValueTypeI32:
   176  		return strconv.FormatInt(int64(api.DecodeI32(g.Value)), 10)
   177  	case api.ValueTypeI64:
   178  		return strconv.FormatInt(int64(g.Value), 10)
   179  	case api.ValueTypeF32:
   180  		return strconv.FormatFloat(float64(api.DecodeF32(g.Value)), 'g', -1, 32)
   181  	case api.ValueTypeF64:
   182  		return strconv.FormatFloat(api.DecodeF64(g.Value), 'g', -1, 64)
   183  	default:
   184  		return "0x" + strconv.FormatUint(g.Value, 16)
   185  	}
   186  }
   187  
   188  func (g *Global) Type() api.ValueType {
   189  	return g.ValueType
   190  }
   191  
   192  func (g *Global) Get() uint64 {
   193  	return g.Value
   194  }
   195  
   196  func GlobalI32(value int32, export ...string) *Global {
   197  	return &Global{ValueType: api.ValueTypeI32, Value: api.EncodeI32(value), ExportNames: export}
   198  }
   199  
   200  func GlobalI64(value int64, export ...string) *Global {
   201  	return &Global{ValueType: api.ValueTypeI64, Value: api.EncodeI64(value), ExportNames: export}
   202  }
   203  
   204  func GlobalF32(value float32, export ...string) *Global {
   205  	return &Global{ValueType: api.ValueTypeF32, Value: api.EncodeF32(value), ExportNames: export}
   206  }
   207  
   208  func GlobalF64(value float64, export ...string) *Global {
   209  	return &Global{ValueType: api.ValueTypeF64, Value: api.EncodeF64(value), ExportNames: export}
   210  }
   211  
   212  // Function is an implementation of the api.Function interface, it represents
   213  // a function in a WebAssembly module.
   214  //
   215  // Until accessed through a Module's method, the function definition's
   216  // ModuleName method returns an empty string and its Index method returns 0.
   217  type Function struct {
   218  	internalapi.WazeroOnlyType
   219  
   220  	// GoModuleFunction may be set to a non-nil value to allow calling of the
   221  	// function via Call or CallWithStack.
   222  	//
   223  	// It is the user's responsibility to ensure that the signature of this
   224  	// implementation matches the ParamTypes and ResultTypes fields.
   225  	GoModuleFunction api.GoModuleFunction
   226  
   227  	// Type lists representing the function signature. Those fields should be
   228  	// set for the function to be properly constructed. The Function's Call
   229  	// and CallWithStack methods will error if those fields are nil.
   230  	ParamTypes  []api.ValueType
   231  	ResultTypes []api.ValueType
   232  
   233  	// Sets of names associated with the function. It is valid to leave those
   234  	// names empty, they are only used for debugging purposes.
   235  	FunctionName string
   236  	DebugName    string
   237  	ParamNames   []string
   238  	ResultNames  []string
   239  	ExportNames  []string
   240  
   241  	// Lazily initialized when accessed through the module.
   242  	module *Module
   243  	index  int
   244  }
   245  
   246  // NewFunction constructs a Function object from a Go function.
   247  //
   248  // The function fn must accept at least two arguments of type context.Context
   249  // and api.Module. Any other arguments and return values must be of type uint32,
   250  // uint64, int32, int64, float32, or float64. The call panics if fn is not a Go
   251  // functionn or has an unsupported signature.
   252  func NewFunction(fn any) *Function {
   253  	functionType := reflect.TypeOf(fn)
   254  	functionValue := reflect.ValueOf(fn)
   255  
   256  	paramTypes := make([]api.ValueType, functionType.NumIn()-2)
   257  	paramFuncs := make([]func(uint64) reflect.Value, len(paramTypes))
   258  
   259  	resultTypes := make([]api.ValueType, functionType.NumOut())
   260  	resultFuncs := make([]func(reflect.Value) uint64, len(resultTypes))
   261  
   262  	for i := range paramTypes {
   263  		var paramType api.ValueType
   264  		var paramFunc func(uint64) reflect.Value
   265  
   266  		switch functionType.In(i + 2).Kind() {
   267  		case reflect.Uint32:
   268  			paramType = api.ValueTypeI32
   269  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeU32(v)) }
   270  		case reflect.Uint64:
   271  			paramType = api.ValueTypeI64
   272  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(v) }
   273  		case reflect.Int32:
   274  			paramType = api.ValueTypeI32
   275  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeI32(v)) }
   276  		case reflect.Int64:
   277  			paramType = api.ValueTypeI64
   278  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(int64(v)) }
   279  		case reflect.Float32:
   280  			paramType = api.ValueTypeF32
   281  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeF32(v)) }
   282  		case reflect.Float64:
   283  			paramType = api.ValueTypeF64
   284  			paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeF64(v)) }
   285  		default:
   286  			panic("cannot construct wasm function from go function of type " + functionType.String())
   287  		}
   288  
   289  		paramTypes[i] = paramType
   290  		paramFuncs[i] = paramFunc
   291  	}
   292  
   293  	for i := range resultTypes {
   294  		var resultType api.ValueType
   295  		var resultFunc func(reflect.Value) uint64
   296  
   297  		switch functionType.Out(i).Kind() {
   298  		case reflect.Uint32:
   299  			resultType = api.ValueTypeI32
   300  			resultFunc = func(v reflect.Value) uint64 { return v.Uint() }
   301  		case reflect.Uint64:
   302  			resultType = api.ValueTypeI64
   303  			resultFunc = func(v reflect.Value) uint64 { return v.Uint() }
   304  		case reflect.Int32:
   305  			resultType = api.ValueTypeI32
   306  			resultFunc = func(v reflect.Value) uint64 { return api.EncodeI32(int32(v.Int())) }
   307  		case reflect.Int64:
   308  			resultType = api.ValueTypeI64
   309  			resultFunc = func(v reflect.Value) uint64 { return api.EncodeI64(v.Int()) }
   310  		case reflect.Float32:
   311  			resultType = api.ValueTypeF32
   312  			resultFunc = func(v reflect.Value) uint64 { return api.EncodeF32(float32(v.Float())) }
   313  		case reflect.Float64:
   314  			resultType = api.ValueTypeF64
   315  			resultFunc = func(v reflect.Value) uint64 { return api.EncodeF64(v.Float()) }
   316  		default:
   317  			panic("cannot construct wasm function from go function of type " + functionType.String())
   318  		}
   319  
   320  		resultTypes[i] = resultType
   321  		resultFuncs[i] = resultFunc
   322  	}
   323  
   324  	return &Function{
   325  		GoModuleFunction: api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) {
   326  			in := make([]reflect.Value, 2+len(paramFuncs))
   327  			in[0] = reflect.ValueOf(ctx)
   328  			in[1] = reflect.ValueOf(mod)
   329  			for i, param := range paramFuncs {
   330  				in[i+2] = param(stack[i])
   331  			}
   332  			out := functionValue.Call(in)
   333  			for i, result := range resultFuncs {
   334  				stack[i] = result(out[i])
   335  			}
   336  		}),
   337  		ParamTypes:  paramTypes,
   338  		ResultTypes: resultTypes,
   339  	}
   340  }
   341  
   342  var (
   343  	errMissingFunctionSignature      = errors.New("missing function signature")
   344  	errMissingFunctionModule         = errors.New("missing function module")
   345  	errMissingFunctionImplementation = errors.New("missing function implementation")
   346  )
   347  
   348  func (f *Function) Definition() api.FunctionDefinition {
   349  	return functionDefinition{function: f}
   350  }
   351  
   352  func (f *Function) Call(ctx context.Context, params ...uint64) ([]uint64, error) {
   353  	stackLen := len(f.ParamTypes)
   354  	if stackLen < len(f.ResultTypes) {
   355  		stackLen = len(f.ResultTypes)
   356  	}
   357  	stack := make([]uint64, stackLen)
   358  	copy(stack, params)
   359  	err := f.CallWithStack(ctx, stack)
   360  	if err != nil {
   361  		for i := range stack {
   362  			stack[i] = 0
   363  		}
   364  	}
   365  	return stack[:len(f.ResultTypes)], err
   366  }
   367  
   368  func (f *Function) CallWithStack(ctx context.Context, stack []uint64) error {
   369  	if f.ParamTypes == nil || f.ResultTypes == nil {
   370  		return errMissingFunctionSignature
   371  	}
   372  	if f.GoModuleFunction == nil {
   373  		return errMissingFunctionImplementation
   374  	}
   375  	if f.module == nil {
   376  		return errMissingFunctionModule
   377  	}
   378  	if exitCode, exited := f.module.ExitStatus(); exited {
   379  		return sys.NewExitError(exitCode)
   380  	}
   381  	f.GoModuleFunction.Call(ctx, f.module, stack)
   382  	return nil
   383  }
   384  
   385  type functionDefinition struct {
   386  	internalapi.WazeroOnlyType
   387  	function *Function
   388  }
   389  
   390  func (def functionDefinition) Name() string {
   391  	return def.function.FunctionName
   392  }
   393  
   394  func (def functionDefinition) DebugName() string {
   395  	if def.function.DebugName != "" {
   396  		return def.function.DebugName
   397  	}
   398  	return fmt.Sprintf("%s.$%d", def.ModuleName(), def.Index())
   399  }
   400  
   401  func (def functionDefinition) GoFunction() any {
   402  	return def.function.GoModuleFunction
   403  }
   404  
   405  func (def functionDefinition) ParamTypes() []api.ValueType {
   406  	return def.function.ParamTypes
   407  }
   408  
   409  func (def functionDefinition) ParamNames() []string {
   410  	return def.function.ParamNames
   411  }
   412  
   413  func (def functionDefinition) ResultTypes() []api.ValueType {
   414  	return def.function.ResultTypes
   415  }
   416  
   417  func (def functionDefinition) ResultNames() []string {
   418  	return def.function.ResultNames
   419  }
   420  
   421  func (def functionDefinition) ModuleName() string {
   422  	if def.function.module != nil {
   423  		return def.function.module.ModuleName
   424  	}
   425  	return ""
   426  }
   427  
   428  func (def functionDefinition) Index() uint32 {
   429  	return uint32(def.function.index)
   430  }
   431  
   432  func (def functionDefinition) Import() (moduleName, name string, isImport bool) {
   433  	return
   434  }
   435  
   436  func (def functionDefinition) ExportNames() []string {
   437  	return def.function.ExportNames
   438  }
   439  
   440  // Memory is an implementation of the api.Memory interface, representing the
   441  // memory of a WebAssembly module.
   442  type Memory struct {
   443  	internalapi.WazeroOnlyType
   444  
   445  	// Byte slices holding the memory pages.
   446  	//
   447  	// It is the user's repsonsibility to ensure that the length of this byte
   448  	// slice is a multiple of the page size.
   449  	Bytes []byte
   450  
   451  	// Min and max number of memory pages which may be held in this memory.
   452  	//
   453  	// Leaving Max to zero means no upper bound.
   454  	Min uint32
   455  	Max uint32
   456  
   457  	// Lazily initialized when accessed through the module.
   458  	module *Module
   459  }
   460  
   461  // NewMemory constructs a Memory object with a buffer of the given size, aligned
   462  // to the closest multiple of the page size.
   463  func NewMemory(size int) *Memory {
   464  	numPages := (size + (PageSize - 1)) / PageSize
   465  	return &Memory{
   466  		Bytes: make([]byte, numPages*PageSize),
   467  		Min:   uint32(numPages),
   468  	}
   469  }
   470  
   471  // NewFixedMemory constructs a Memory object of the given size. The returned
   472  // memory is configured with a max limit to prevent growing beyond its initial
   473  // size.
   474  func NewFixedMemory(size int) *Memory {
   475  	memory := NewMemory(size)
   476  	memory.Max = memory.Min
   477  	return memory
   478  }
   479  
   480  // The PageSize constant defines the size of WebAssembly memory pages in bytes.
   481  //
   482  // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#page-size
   483  const PageSize = 65536
   484  
   485  func (m *Memory) Definition() api.MemoryDefinition {
   486  	return memoryDefinition{memory: m}
   487  }
   488  
   489  func (m *Memory) Size() uint32 {
   490  	return uint32(len(m.Bytes))
   491  }
   492  
   493  func (m *Memory) Grow(deltaPages uint32) (previousPages uint32, ok bool) {
   494  	previousPages = uint32(len(m.Bytes) / PageSize)
   495  	numPages := previousPages + deltaPages
   496  	if m.Max != 0 && numPages > m.Max {
   497  		return previousPages, false
   498  	}
   499  	bytes := make([]byte, PageSize*numPages)
   500  	copy(bytes, m.Bytes)
   501  	m.Bytes = bytes
   502  	return previousPages, true
   503  }
   504  
   505  func (m *Memory) ReadByte(offset uint32) (byte, bool) {
   506  	if m.isOutOfRange(offset, 1) {
   507  		return 0, false
   508  	}
   509  	return m.Bytes[offset], true
   510  }
   511  
   512  func (m *Memory) ReadUint16Le(offset uint32) (uint16, bool) {
   513  	if m.isOutOfRange(offset, 2) {
   514  		return 0, false
   515  	}
   516  	return binary.LittleEndian.Uint16(m.Bytes[offset:]), true
   517  }
   518  
   519  func (m *Memory) ReadUint32Le(offset uint32) (uint32, bool) {
   520  	if m.isOutOfRange(offset, 4) {
   521  		return 0, false
   522  	}
   523  	return binary.LittleEndian.Uint32(m.Bytes[offset:]), true
   524  }
   525  
   526  func (m *Memory) ReadUint64Le(offset uint32) (uint64, bool) {
   527  	if m.isOutOfRange(offset, 8) {
   528  		return 0, false
   529  	}
   530  	return binary.LittleEndian.Uint64(m.Bytes[offset:]), true
   531  }
   532  
   533  func (m *Memory) ReadFloat32Le(offset uint32) (float32, bool) {
   534  	v, ok := m.ReadUint32Le(offset)
   535  	return math.Float32frombits(v), ok
   536  }
   537  
   538  func (m *Memory) ReadFloat64Le(offset uint32) (float64, bool) {
   539  	v, ok := m.ReadUint64Le(offset)
   540  	return math.Float64frombits(v), ok
   541  }
   542  
   543  func (m *Memory) Read(offset, length uint32) ([]byte, bool) {
   544  	if m.isOutOfRange(offset, length) {
   545  		return nil, false
   546  	}
   547  	return m.Bytes[offset : offset+length : offset+length], true
   548  }
   549  
   550  func (m *Memory) WriteByte(offset uint32, value byte) bool {
   551  	if m.isOutOfRange(offset, 1) {
   552  		return false
   553  	}
   554  	m.Bytes[offset] = value
   555  	return true
   556  }
   557  
   558  func (m *Memory) WriteUint16Le(offset uint32, value uint16) bool {
   559  	if m.isOutOfRange(offset, 2) {
   560  		return false
   561  	}
   562  	binary.LittleEndian.PutUint16(m.Bytes[offset:], value)
   563  	return true
   564  }
   565  
   566  func (m *Memory) WriteUint32Le(offset uint32, value uint32) bool {
   567  	if m.isOutOfRange(offset, 4) {
   568  		return false
   569  	}
   570  	binary.LittleEndian.PutUint32(m.Bytes[offset:], value)
   571  	return true
   572  }
   573  
   574  func (m *Memory) WriteUint64Le(offset uint32, value uint64) bool {
   575  	if m.isOutOfRange(offset, 4) {
   576  		return false
   577  	}
   578  	binary.LittleEndian.PutUint64(m.Bytes[offset:], value)
   579  	return true
   580  }
   581  
   582  func (m *Memory) WriteFloat32Le(offset uint32, value float32) bool {
   583  	return m.WriteUint32Le(offset, math.Float32bits(value))
   584  }
   585  
   586  func (m *Memory) WriteFloat64Le(offset uint32, value float64) bool {
   587  	return m.WriteUint64Le(offset, math.Float64bits(value))
   588  }
   589  
   590  func (m *Memory) Write(offset uint32, value []byte) bool {
   591  	if m.isOutOfRange(offset, uint32(len(value))) {
   592  		return false
   593  	}
   594  	copy(m.Bytes[offset:], value)
   595  	return true
   596  }
   597  
   598  func (m *Memory) WriteString(offset uint32, value string) bool {
   599  	if m.isOutOfRange(offset, uint32(len(value))) {
   600  		return false
   601  	}
   602  	copy(m.Bytes[offset:], value)
   603  	return true
   604  }
   605  
   606  func (m *Memory) isOutOfRange(offset, length uint32) bool {
   607  	size := m.Size()
   608  	return offset >= size || length > size || offset > (size-length)
   609  }
   610  
   611  type memoryDefinition struct {
   612  	internalapi.WazeroOnlyType
   613  	memory *Memory
   614  }
   615  
   616  func (def memoryDefinition) ModuleName() string {
   617  	if def.memory.module != nil {
   618  		return def.memory.module.ModuleName
   619  	}
   620  	return ""
   621  }
   622  
   623  func (def memoryDefinition) Index() uint32 {
   624  	return 0
   625  }
   626  
   627  func (def memoryDefinition) Import() (moduleName, name string, isImport bool) {
   628  	return
   629  }
   630  
   631  func (def memoryDefinition) ExportNames() []string {
   632  	if def.memory.module != nil {
   633  		return []string{"memory"}
   634  	}
   635  	return nil
   636  }
   637  
   638  func (def memoryDefinition) Min() uint32 {
   639  	return def.memory.Min
   640  }
   641  
   642  func (def memoryDefinition) Max() (uint32, bool) {
   643  	return def.memory.Max, def.memory.Max != 0
   644  }
   645  
   646  var (
   647  	_ api.Module   = (*Module)(nil)
   648  	_ api.Function = (*Function)(nil)
   649  	_ api.Global   = (*Global)(nil)
   650  )