wa-lang.org/wazero@v1.0.2/internal/wasm/gofunc.go (about)

     1  package wasm
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"reflect"
    10  
    11  	"wa-lang.org/wazero/api"
    12  )
    13  
    14  type paramsKind byte
    15  
    16  const (
    17  	paramsKindNoContext paramsKind = iota
    18  	paramsKindContext
    19  	paramsKindContextModule
    20  )
    21  
    22  // Below are reflection code to get the interface type used to parse functions and set values.
    23  
    24  var (
    25  	moduleType    = reflect.TypeOf((*api.Module)(nil)).Elem()
    26  	goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
    27  	errorType     = reflect.TypeOf((*error)(nil)).Elem()
    28  )
    29  
    30  // compile-time check to ensure reflectGoModuleFunction implements
    31  // api.GoModuleFunction.
    32  var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil)
    33  
    34  type reflectGoModuleFunction struct {
    35  	fn              *reflect.Value
    36  	params, results []ValueType
    37  }
    38  
    39  // Call implements the same method as documented on api.GoModuleFunction.
    40  func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, stack []uint64) {
    41  	callGoFunc(ctx, mod, f.fn, stack)
    42  }
    43  
    44  // EqualTo is exposed for testing.
    45  func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool {
    46  	if f2, ok := that.(*reflectGoModuleFunction); !ok {
    47  		return false
    48  	} else {
    49  		// TODO compare reflect pointers
    50  		return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
    51  	}
    52  }
    53  
    54  // compile-time check to ensure reflectGoFunction implements api.GoFunction.
    55  var _ api.GoFunction = (*reflectGoFunction)(nil)
    56  
    57  type reflectGoFunction struct {
    58  	fn              *reflect.Value
    59  	pk              paramsKind
    60  	params, results []ValueType
    61  }
    62  
    63  // EqualTo is exposed for testing.
    64  func (f *reflectGoFunction) EqualTo(that interface{}) bool {
    65  	if f2, ok := that.(*reflectGoFunction); !ok {
    66  		return false
    67  	} else {
    68  		// TODO compare reflect pointers
    69  		return f.pk == f2.pk &&
    70  			bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
    71  	}
    72  }
    73  
    74  // Call implements the same method as documented on api.GoFunction.
    75  func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
    76  	if f.pk == paramsKindNoContext {
    77  		ctx = nil
    78  	}
    79  	callGoFunc(ctx, nil, f.fn, stack)
    80  }
    81  
    82  // PopValues pops the specified number of api.ValueType parameters off the
    83  // stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction.
    84  //
    85  // For example, if the host function F requires the (x1 uint32, x2 float32)
    86  // parameters, and the stack is [..., A, B], then the function is called as
    87  // F(A, B) where A and B are interpreted as uint32 and float32 respectively.
    88  //
    89  // Note: the popper intentionally doesn't return bool or error because the
    90  // caller's stack depth is trusted.
    91  func PopValues(count int, popper func() uint64) []uint64 {
    92  	if count == 0 {
    93  		return nil
    94  	}
    95  	params := make([]uint64, count)
    96  	for i := count - 1; i >= 0; i-- {
    97  		params[i] = popper()
    98  	}
    99  	return params
   100  }
   101  
   102  // callGoFunc executes the reflective function by converting params to Go
   103  // types. The results of the function call are converted back to api.ValueType.
   104  func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
   105  	tp := fn.Type()
   106  
   107  	var in []reflect.Value
   108  	pLen := tp.NumIn()
   109  	if pLen != 0 {
   110  		in = make([]reflect.Value, pLen)
   111  
   112  		i := 0
   113  		if ctx != nil {
   114  			in[0] = newContextVal(ctx)
   115  			i++
   116  		}
   117  		if mod != nil {
   118  			in[1] = newModuleVal(mod)
   119  			i++
   120  		}
   121  
   122  		for j := 0; i < pLen; i++ {
   123  			next := tp.In(i)
   124  			val := reflect.New(next).Elem()
   125  			k := next.Kind()
   126  			raw := stack[j]
   127  			j++
   128  
   129  			switch k {
   130  			case reflect.Float32:
   131  				val.SetFloat(float64(math.Float32frombits(uint32(raw))))
   132  			case reflect.Float64:
   133  				val.SetFloat(math.Float64frombits(raw))
   134  			case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   135  				val.SetUint(raw)
   136  			case reflect.Int32, reflect.Int64:
   137  				val.SetInt(int64(raw))
   138  			default:
   139  				panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
   140  			}
   141  			in[i] = val
   142  		}
   143  	}
   144  
   145  	// Execute the host function and push back the call result onto the stack.
   146  	for i, ret := range fn.Call(in) {
   147  		switch ret.Kind() {
   148  		case reflect.Float32:
   149  			stack[i] = uint64(math.Float32bits(float32(ret.Float())))
   150  		case reflect.Float64:
   151  			stack[i] = math.Float64bits(ret.Float())
   152  		case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   153  			stack[i] = ret.Uint()
   154  		case reflect.Int32, reflect.Int64:
   155  			stack[i] = uint64(ret.Int())
   156  		default:
   157  			panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
   158  		}
   159  	}
   160  }
   161  
   162  func newContextVal(ctx context.Context) reflect.Value {
   163  	val := reflect.New(goContextType).Elem()
   164  	val.Set(reflect.ValueOf(ctx))
   165  	return val
   166  }
   167  
   168  func newModuleVal(m api.Module) reflect.Value {
   169  	val := reflect.New(moduleType).Elem()
   170  	val.Set(reflect.ValueOf(m))
   171  	return val
   172  }
   173  
   174  // MustParseGoReflectFuncCode parses Code from the go function or panics.
   175  //
   176  // Exposing this simplifies FunctionDefinition of host functions in built-in host
   177  // modules and tests.
   178  func MustParseGoReflectFuncCode(fn interface{}) *Code {
   179  	_, _, code, err := parseGoReflectFunc(fn)
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  	return code
   184  }
   185  
   186  func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code, err error) {
   187  	fnV := reflect.ValueOf(fn)
   188  	p := fnV.Type()
   189  
   190  	if fnV.Kind() != reflect.Func {
   191  		err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
   192  		return
   193  	}
   194  
   195  	pk, kindErr := kind(p)
   196  	if kindErr != nil {
   197  		err = kindErr
   198  		return
   199  	}
   200  
   201  	pOffset := 0
   202  	switch pk {
   203  	case paramsKindNoContext:
   204  	case paramsKindContext:
   205  		pOffset = 1
   206  	case paramsKindContextModule:
   207  		pOffset = 2
   208  	}
   209  
   210  	pCount := p.NumIn() - pOffset
   211  	if pCount > 0 {
   212  		params = make([]ValueType, pCount)
   213  	}
   214  	for i := 0; i < len(params); i++ {
   215  		pI := p.In(i + pOffset)
   216  		if t, ok := getTypeOf(pI.Kind()); ok {
   217  			params[i] = t
   218  			continue
   219  		}
   220  
   221  		// Now, we will definitely err, decide which message is best
   222  		var arg0Type reflect.Type
   223  		if hc := pI.Implements(moduleType); hc {
   224  			arg0Type = moduleType
   225  		} else if gc := pI.Implements(goContextType); gc {
   226  			arg0Type = goContextType
   227  		}
   228  
   229  		if arg0Type != nil {
   230  			err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
   231  		} else {
   232  			err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
   233  		}
   234  		return
   235  	}
   236  
   237  	rCount := p.NumOut()
   238  	if rCount > 0 {
   239  		results = make([]ValueType, rCount)
   240  	}
   241  	for i := 0; i < len(results); i++ {
   242  		rI := p.Out(i)
   243  		if t, ok := getTypeOf(rI.Kind()); ok {
   244  			results[i] = t
   245  			continue
   246  		}
   247  
   248  		// Now, we will definitely err, decide which message is best
   249  		if rI.Implements(errorType) {
   250  			err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
   251  		} else {
   252  			err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
   253  		}
   254  		return
   255  	}
   256  
   257  	code = &Code{IsHostFunction: true}
   258  	if pk == paramsKindContextModule {
   259  		code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
   260  	} else {
   261  		code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
   262  	}
   263  	return
   264  }
   265  
   266  func kind(p reflect.Type) (paramsKind, error) {
   267  	pCount := p.NumIn()
   268  	if pCount > 0 && p.In(0).Kind() == reflect.Interface {
   269  		p0 := p.In(0)
   270  		if p0.Implements(moduleType) {
   271  			return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
   272  		} else if p0.Implements(goContextType) {
   273  			if pCount >= 2 && p.In(1).Implements(moduleType) {
   274  				return paramsKindContextModule, nil
   275  			}
   276  			return paramsKindContext, nil
   277  		}
   278  	}
   279  	// Without context param allows portability with reflective runtimes.
   280  	// This allows people to more easily port to wazero.
   281  	return paramsKindNoContext, nil
   282  }
   283  
   284  func getTypeOf(kind reflect.Kind) (ValueType, bool) {
   285  	switch kind {
   286  	case reflect.Float64:
   287  		return ValueTypeF64, true
   288  	case reflect.Float32:
   289  		return ValueTypeF32, true
   290  	case reflect.Int32, reflect.Uint32:
   291  		return ValueTypeI32, true
   292  	case reflect.Int64, reflect.Uint64:
   293  		return ValueTypeI64, true
   294  	case reflect.Uintptr:
   295  		return ValueTypeExternref, true
   296  	default:
   297  		return 0x00, false
   298  	}
   299  }