github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/wasm/gofunc.go (about)

     1  package wasm
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"reflect"
    10  
    11  	"github.com/bananabytelabs/wazero/api"
    12  )
    13  
    14  type paramsKind byte
    15  
    16  const (
    17  	paramsKindNoContext paramsKind = iota
    18  	paramsKindContext
    19  	paramsKindContextModule
    20  )
    21  
    22  // Below are reflection code to get the interface type used to parse functions and set values.
    23  
    24  var (
    25  	moduleType    = reflect.TypeOf((*api.Module)(nil)).Elem()
    26  	goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
    27  	errorType     = reflect.TypeOf((*error)(nil)).Elem()
    28  )
    29  
    30  // compile-time check to ensure reflectGoModuleFunction implements
    31  // api.GoModuleFunction.
    32  var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil)
    33  
    34  type reflectGoModuleFunction struct {
    35  	fn              *reflect.Value
    36  	params, results []ValueType
    37  }
    38  
    39  // Call implements the same method as documented on api.GoModuleFunction.
    40  func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, stack []uint64) {
    41  	callGoFunc(ctx, mod, f.fn, stack)
    42  }
    43  
    44  // EqualTo is exposed for testing.
    45  func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool {
    46  	if f2, ok := that.(*reflectGoModuleFunction); !ok {
    47  		return false
    48  	} else {
    49  		// TODO compare reflect pointers
    50  		return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
    51  	}
    52  }
    53  
    54  // compile-time check to ensure reflectGoFunction implements api.GoFunction.
    55  var _ api.GoFunction = (*reflectGoFunction)(nil)
    56  
    57  type reflectGoFunction struct {
    58  	fn              *reflect.Value
    59  	pk              paramsKind
    60  	params, results []ValueType
    61  }
    62  
    63  // EqualTo is exposed for testing.
    64  func (f *reflectGoFunction) EqualTo(that interface{}) bool {
    65  	if f2, ok := that.(*reflectGoFunction); !ok {
    66  		return false
    67  	} else {
    68  		// TODO compare reflect pointers
    69  		return f.pk == f2.pk &&
    70  			bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
    71  	}
    72  }
    73  
    74  // Call implements the same method as documented on api.GoFunction.
    75  func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
    76  	if f.pk == paramsKindNoContext {
    77  		ctx = nil
    78  	}
    79  	callGoFunc(ctx, nil, f.fn, stack)
    80  }
    81  
    82  // callGoFunc executes the reflective function by converting params to Go
    83  // types. The results of the function call are converted back to api.ValueType.
    84  func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
    85  	tp := fn.Type()
    86  
    87  	var in []reflect.Value
    88  	pLen := tp.NumIn()
    89  	if pLen != 0 {
    90  		in = make([]reflect.Value, pLen)
    91  
    92  		i := 0
    93  		if ctx != nil {
    94  			in[0] = newContextVal(ctx)
    95  			i++
    96  		}
    97  		if mod != nil {
    98  			in[1] = newModuleVal(mod)
    99  			i++
   100  		}
   101  
   102  		for j := 0; i < pLen; i++ {
   103  			next := tp.In(i)
   104  			val := reflect.New(next).Elem()
   105  			k := next.Kind()
   106  			raw := stack[j]
   107  			j++
   108  
   109  			switch k {
   110  			case reflect.Float32:
   111  				val.SetFloat(float64(math.Float32frombits(uint32(raw))))
   112  			case reflect.Float64:
   113  				val.SetFloat(math.Float64frombits(raw))
   114  			case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   115  				val.SetUint(raw)
   116  			case reflect.Int32, reflect.Int64:
   117  				val.SetInt(int64(raw))
   118  			default:
   119  				panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
   120  			}
   121  			in[i] = val
   122  		}
   123  	}
   124  
   125  	// Execute the host function and push back the call result onto the stack.
   126  	for i, ret := range fn.Call(in) {
   127  		switch ret.Kind() {
   128  		case reflect.Float32:
   129  			stack[i] = uint64(math.Float32bits(float32(ret.Float())))
   130  		case reflect.Float64:
   131  			stack[i] = math.Float64bits(ret.Float())
   132  		case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   133  			stack[i] = ret.Uint()
   134  		case reflect.Int32, reflect.Int64:
   135  			stack[i] = uint64(ret.Int())
   136  		default:
   137  			panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
   138  		}
   139  	}
   140  }
   141  
   142  func newContextVal(ctx context.Context) reflect.Value {
   143  	val := reflect.New(goContextType).Elem()
   144  	val.Set(reflect.ValueOf(ctx))
   145  	return val
   146  }
   147  
   148  func newModuleVal(m api.Module) reflect.Value {
   149  	val := reflect.New(moduleType).Elem()
   150  	val.Set(reflect.ValueOf(m))
   151  	return val
   152  }
   153  
   154  // MustParseGoReflectFuncCode parses Code from the go function or panics.
   155  //
   156  // Exposing this simplifies FunctionDefinition of host functions in built-in host
   157  // modules and tests.
   158  func MustParseGoReflectFuncCode(fn interface{}) Code {
   159  	_, _, code, err := parseGoReflectFunc(fn)
   160  	if err != nil {
   161  		panic(err)
   162  	}
   163  	return code
   164  }
   165  
   166  func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code Code, err error) {
   167  	fnV := reflect.ValueOf(fn)
   168  	p := fnV.Type()
   169  
   170  	if fnV.Kind() != reflect.Func {
   171  		err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
   172  		return
   173  	}
   174  
   175  	pk, kindErr := kind(p)
   176  	if kindErr != nil {
   177  		err = kindErr
   178  		return
   179  	}
   180  
   181  	pOffset := 0
   182  	switch pk {
   183  	case paramsKindNoContext:
   184  	case paramsKindContext:
   185  		pOffset = 1
   186  	case paramsKindContextModule:
   187  		pOffset = 2
   188  	}
   189  
   190  	pCount := p.NumIn() - pOffset
   191  	if pCount > 0 {
   192  		params = make([]ValueType, pCount)
   193  	}
   194  	for i := 0; i < len(params); i++ {
   195  		pI := p.In(i + pOffset)
   196  		if t, ok := getTypeOf(pI.Kind()); ok {
   197  			params[i] = t
   198  			continue
   199  		}
   200  
   201  		// Now, we will definitely err, decide which message is best
   202  		var arg0Type reflect.Type
   203  		if hc := pI.Implements(moduleType); hc {
   204  			arg0Type = moduleType
   205  		} else if gc := pI.Implements(goContextType); gc {
   206  			arg0Type = goContextType
   207  		}
   208  
   209  		if arg0Type != nil {
   210  			err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
   211  		} else {
   212  			err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
   213  		}
   214  		return
   215  	}
   216  
   217  	rCount := p.NumOut()
   218  	if rCount > 0 {
   219  		results = make([]ValueType, rCount)
   220  	}
   221  	for i := 0; i < len(results); i++ {
   222  		rI := p.Out(i)
   223  		if t, ok := getTypeOf(rI.Kind()); ok {
   224  			results[i] = t
   225  			continue
   226  		}
   227  
   228  		// Now, we will definitely err, decide which message is best
   229  		if rI.Implements(errorType) {
   230  			err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
   231  		} else {
   232  			err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
   233  		}
   234  		return
   235  	}
   236  
   237  	code = Code{}
   238  	if pk == paramsKindContextModule {
   239  		code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
   240  	} else {
   241  		code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
   242  	}
   243  	return
   244  }
   245  
   246  func kind(p reflect.Type) (paramsKind, error) {
   247  	pCount := p.NumIn()
   248  	if pCount > 0 && p.In(0).Kind() == reflect.Interface {
   249  		p0 := p.In(0)
   250  		if p0.Implements(moduleType) {
   251  			return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
   252  		} else if p0.Implements(goContextType) {
   253  			if pCount >= 2 && p.In(1).Implements(moduleType) {
   254  				return paramsKindContextModule, nil
   255  			}
   256  			return paramsKindContext, nil
   257  		}
   258  	}
   259  	// Without context param allows portability with reflective runtimes.
   260  	// This allows people to more easily port to wazero.
   261  	return paramsKindNoContext, nil
   262  }
   263  
   264  func getTypeOf(kind reflect.Kind) (ValueType, bool) {
   265  	switch kind {
   266  	case reflect.Float64:
   267  		return ValueTypeF64, true
   268  	case reflect.Float32:
   269  		return ValueTypeF32, true
   270  	case reflect.Int32, reflect.Uint32:
   271  		return ValueTypeI32, true
   272  	case reflect.Int64, reflect.Uint64:
   273  		return ValueTypeI64, true
   274  	case reflect.Uintptr:
   275  		return ValueTypeExternref, true
   276  	default:
   277  		return 0x00, false
   278  	}
   279  }