github.com/taubyte/vm-wasm-utils@v1.0.2/gofunc.go (about)

     1  package wasm
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"reflect"
     8  
     9  	"github.com/tetratelabs/wazero/api"
    10  )
    11  
    12  // FunctionKind identifies the type of function that can be called.
    13  type FunctionKind byte
    14  
    15  const (
    16  	// FunctionKindWasm is not a Go function: it is implemented in Wasm.
    17  	FunctionKindWasm FunctionKind = iota
    18  	// FunctionKindGoNoContext is a function implemented in Go, with a signature matching FunctionType.
    19  	FunctionKindGoNoContext
    20  	// FunctionKindGoContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is
    21  	// a context.Context.
    22  	FunctionKindGoContext
    23  	// FunctionKindGoModule is a function implemented in Go, with a signature matching FunctionType, except arg
    24  	// zero is an api.Module.
    25  	FunctionKindGoModule
    26  	// FunctionKindGoContextModule is a function implemented in Go, with a signature matching FunctionType, except arg
    27  	// zero is a context.Context and arg one is an api.Module.
    28  	FunctionKindGoContextModule
    29  )
    30  
    31  // Below are reflection code to get the interface type used to parse functions and set values.
    32  
    33  var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem()
    34  var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
    35  var errorType = reflect.TypeOf((*error)(nil)).Elem()
    36  
    37  // PopGoFuncParams pops the correct number of parameters off the stack into a parameter slice for use in CallGoFunc
    38  //
    39  // For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and
    40  // the stack is [..., A, B], then the function is called as F(A, B) where A and B are interpreted
    41  // as uint32 and float32 respectively.
    42  func PopGoFuncParams(f *FunctionInstance, popParam func() uint64) []uint64 {
    43  	// First, determine how many values we need to pop
    44  	paramCount := f.GoFunc.Type().NumIn()
    45  	switch f.Kind {
    46  	case FunctionKindGoNoContext:
    47  	case FunctionKindGoContextModule:
    48  		paramCount -= 2
    49  	default:
    50  		paramCount--
    51  	}
    52  
    53  	return PopValues(paramCount, popParam)
    54  }
    55  
    56  // PopValues pops api.ValueType values from the stack and returns them in reverse order.
    57  //
    58  // Note: the popper intentionally doesn't return bool or error because the caller's stack depth is trusted.
    59  func PopValues(count int, popper func() uint64) []uint64 {
    60  	if count == 0 {
    61  		return nil
    62  	}
    63  	params := make([]uint64, count)
    64  	for i := count - 1; i >= 0; i-- {
    65  		params[i] = popper()
    66  	}
    67  	return params
    68  }
    69  
    70  // CallGoFunc executes the FunctionInstance.GoFunc by converting params to Go types. The results of the function call
    71  // are converted back to api.ValueType.
    72  //
    73  // * callCtx is passed to the host function as a first argument.
    74  //
    75  // Note: ctx must use the caller's memory, which might be different from the defining module on an imported function.
    76  func CallGoFunc(ctx context.Context, callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 {
    77  	tp := f.GoFunc.Type()
    78  
    79  	var in []reflect.Value
    80  	if tp.NumIn() != 0 {
    81  		in = make([]reflect.Value, tp.NumIn())
    82  
    83  		i := 0
    84  		switch f.Kind {
    85  		case FunctionKindGoContext:
    86  			in[0] = newContextVal(ctx)
    87  			i = 1
    88  		case FunctionKindGoModule:
    89  			in[0] = newModuleVal(callCtx)
    90  			i = 1
    91  		case FunctionKindGoContextModule:
    92  			in[0] = newContextVal(ctx)
    93  			in[1] = newModuleVal(callCtx)
    94  			i = 2
    95  		}
    96  
    97  		for _, raw := range params {
    98  			val := reflect.New(tp.In(i)).Elem()
    99  			k := tp.In(i).Kind()
   100  			switch k {
   101  			case reflect.Float32:
   102  				val.SetFloat(float64(math.Float32frombits(uint32(raw))))
   103  			case reflect.Float64:
   104  				val.SetFloat(math.Float64frombits(raw))
   105  			case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   106  				val.SetUint(raw)
   107  			case reflect.Int32, reflect.Int64:
   108  				val.SetInt(int64(raw))
   109  			default:
   110  				panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
   111  			}
   112  			in[i] = val
   113  			i++
   114  		}
   115  	}
   116  
   117  	// Execute the host function and push back the call result onto the stack.
   118  	var results []uint64
   119  	if tp.NumOut() > 0 {
   120  		results = make([]uint64, 0, tp.NumOut())
   121  	}
   122  	for i, ret := range f.GoFunc.Call(in) {
   123  		switch ret.Kind() {
   124  		case reflect.Float32:
   125  			results = append(results, uint64(math.Float32bits(float32(ret.Float()))))
   126  		case reflect.Float64:
   127  			results = append(results, math.Float64bits(ret.Float()))
   128  		case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   129  			results = append(results, ret.Uint())
   130  		case reflect.Int32, reflect.Int64:
   131  			results = append(results, uint64(ret.Int()))
   132  		default:
   133  			panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
   134  		}
   135  	}
   136  	return results
   137  }
   138  
   139  func newContextVal(ctx context.Context) reflect.Value {
   140  	val := reflect.New(goContextType).Elem()
   141  	val.Set(reflect.ValueOf(ctx))
   142  	return val
   143  }
   144  
   145  func newModuleVal(m api.Module) reflect.Value {
   146  	val := reflect.New(moduleType).Elem()
   147  	val.Set(reflect.ValueOf(m))
   148  	return val
   149  }
   150  
   151  // MustParseGoFuncCode parses Code from the go function or panics.
   152  //
   153  // Exposing this simplifies definition of host functions in built-in host
   154  // modules and tests.
   155  func MustParseGoFuncCode(fn interface{}) *Code {
   156  	_, _, code, err := parseGoFunc(fn)
   157  	if err != nil {
   158  		panic(err)
   159  	}
   160  	return code
   161  }
   162  
   163  func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err error) {
   164  	fnV := reflect.ValueOf(fn)
   165  	p := fnV.Type()
   166  
   167  	if fnV.Kind() != reflect.Func {
   168  		err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
   169  		return
   170  	}
   171  
   172  	fk := kind(p)
   173  	code = &Code{IsHostFunction: true, Kind: fk, GoFunc: &fnV}
   174  
   175  	pOffset := 0
   176  	switch fk {
   177  	case FunctionKindGoNoContext:
   178  	case FunctionKindGoContextModule:
   179  		pOffset = 2
   180  	default:
   181  		pOffset = 1
   182  	}
   183  
   184  	pCount := p.NumIn() - pOffset
   185  	if pCount > 0 {
   186  		params = make([]ValueType, pCount)
   187  	}
   188  	for i := 0; i < len(params); i++ {
   189  		pI := p.In(i + pOffset)
   190  		if t, ok := getTypeOf(pI.Kind()); ok {
   191  			params[i] = t
   192  			continue
   193  		}
   194  
   195  		// Now, we will definitely err, decide which message is best
   196  		var arg0Type reflect.Type
   197  		if hc := pI.Implements(moduleType); hc {
   198  			arg0Type = moduleType
   199  		} else if gc := pI.Implements(goContextType); gc {
   200  			arg0Type = goContextType
   201  		}
   202  
   203  		if arg0Type != nil {
   204  			err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
   205  		} else {
   206  			err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
   207  		}
   208  		return
   209  	}
   210  
   211  	rCount := p.NumOut()
   212  	if rCount > 0 {
   213  		results = make([]ValueType, rCount)
   214  	}
   215  	for i := 0; i < len(results); i++ {
   216  		rI := p.Out(i)
   217  		if t, ok := getTypeOf(rI.Kind()); ok {
   218  			results[i] = t
   219  			continue
   220  		}
   221  
   222  		// Now, we will definitely err, decide which message is best
   223  		if rI.Implements(errorType) {
   224  			err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
   225  		} else {
   226  			err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
   227  		}
   228  		return
   229  	}
   230  	return
   231  }
   232  
   233  func kind(p reflect.Type) FunctionKind {
   234  	pCount := p.NumIn()
   235  	if pCount > 0 && p.In(0).Kind() == reflect.Interface {
   236  		p0 := p.In(0)
   237  		if p0.Implements(moduleType) {
   238  			return FunctionKindGoModule
   239  		} else if p0.Implements(goContextType) {
   240  			if pCount >= 2 && p.In(1).Implements(moduleType) {
   241  				return FunctionKindGoContextModule
   242  			}
   243  			return FunctionKindGoContext
   244  		}
   245  	}
   246  	return FunctionKindGoNoContext
   247  }
   248  
   249  func getTypeOf(kind reflect.Kind) (ValueType, bool) {
   250  	switch kind {
   251  	case reflect.Float64:
   252  		return ValueTypeF64, true
   253  	case reflect.Float32:
   254  		return ValueTypeF32, true
   255  	case reflect.Int32, reflect.Uint32:
   256  		return ValueTypeI32, true
   257  	case reflect.Int64, reflect.Uint64:
   258  		return ValueTypeI64, true
   259  	case reflect.Uintptr:
   260  		return ValueTypeExternref, true
   261  	default:
   262  		return 0x00, false
   263  	}
   264  }