tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/fn/call.go (about)

     1  package fn
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/mitchellh/mapstructure"
     8  )
     9  
    10  var errorInterface = reflect.TypeOf((*error)(nil)).Elem()
    11  
    12  // Call wraps invoking a function via reflection, converting the arguments with
    13  // ArgsTo and the returns with ParseReturn. fn argument can be a function
    14  // or a reflect.Value for a function.
    15  func Call(fn any, args []any) (_ []any, err error) {
    16  	fnval := reflect.ValueOf(fn)
    17  	if rv, ok := fn.(reflect.Value); ok {
    18  		fnval = rv
    19  	}
    20  	fnParams, err := ArgsTo(fnval.Type(), args)
    21  	if err != nil {
    22  		return nil, err
    23  	}
    24  	fnReturn := fnval.Call(fnParams)
    25  	return ParseReturn(fnReturn)
    26  }
    27  
    28  // ArgsTo converts the arguments into `reflect.Value`s suitable to pass as
    29  // parameters to a function with the given type via reflection.
    30  func ArgsTo(fntyp reflect.Type, args []any) ([]reflect.Value, error) {
    31  	if len(args) != fntyp.NumIn() {
    32  		return nil, fmt.Errorf("fn: expected %d params, got %d", fntyp.NumIn(), len(args))
    33  	}
    34  	fnParams := make([]reflect.Value, len(args))
    35  	for idx, param := range args {
    36  		switch fntyp.In(idx).Kind() {
    37  		case reflect.Struct:
    38  			// decode to struct type using mapstructure
    39  			arg := reflect.New(fntyp.In(idx))
    40  			if err := mapstructure.Decode(param, arg.Interface()); err != nil {
    41  				return nil, fmt.Errorf("fn: mapstructure: %s", err.Error())
    42  			}
    43  			fnParams[idx] = ensureType(arg.Elem(), fntyp.In(idx))
    44  		case reflect.Slice:
    45  			rv := reflect.ValueOf(param)
    46  			// decode slice of structs to struct type using mapstructure
    47  			if fntyp.In(idx).Elem().Kind() == reflect.Struct {
    48  				nv := reflect.MakeSlice(fntyp.In(idx), rv.Len(), rv.Len())
    49  				for i := 0; i < rv.Len(); i++ {
    50  					ref := reflect.New(nv.Index(i).Type())
    51  					if err := mapstructure.Decode(rv.Index(i).Interface(), ref.Interface()); err != nil {
    52  						return nil, fmt.Errorf("fn: mapstructure: %s", err.Error())
    53  					}
    54  					nv.Index(i).Set(reflect.Indirect(ref))
    55  				}
    56  				rv = nv
    57  			}
    58  			fnParams[idx] = rv
    59  		default:
    60  			// if int is expected but got float64 assume json-like encoding and cast float to int
    61  			if fntyp.In(idx).Kind() == reflect.Int && reflect.TypeOf(param).Kind() == reflect.Float64 {
    62  				param = int(param.(float64))
    63  			}
    64  			fnParams[idx] = ensureType(reflect.ValueOf(param), fntyp.In(idx))
    65  		}
    66  	}
    67  	return fnParams, nil
    68  }
    69  
    70  // ParseReturn splits the results of reflect.Call() into the values, and
    71  // possibly an error.
    72  // If the last value is a non-nil error, this will return `nil, err`.
    73  // If the last value is a nil error it will be removed from the value list.
    74  // Any remaining values will be converted and returned as `any` typed values.
    75  func ParseReturn(ret []reflect.Value) ([]any, error) {
    76  	if len(ret) == 0 {
    77  		return nil, nil
    78  	}
    79  	last := ret[len(ret)-1]
    80  	if last.Type().Implements(errorInterface) {
    81  		if !last.IsNil() {
    82  			return nil, last.Interface().(error)
    83  		}
    84  		ret = ret[:len(ret)-1]
    85  	}
    86  	out := make([]any, len(ret))
    87  	for i, r := range ret {
    88  		out[i] = r.Interface()
    89  	}
    90  	return out, nil
    91  }
    92  
    93  // ensureType ensures a value is converted to the expected
    94  // defined type from a convertable underlying type
    95  func ensureType(v reflect.Value, t reflect.Type) reflect.Value {
    96  	if !v.IsValid() {
    97  		// handle nil values with zero value of expected type
    98  		return reflect.New(t).Elem()
    99  	}
   100  	nv := v
   101  	if v.Type().Kind() == reflect.Slice && v.Type().Elem() != t {
   102  		switch t.Kind() {
   103  		case reflect.Array:
   104  			nv = reflect.Indirect(reflect.New(t))
   105  			for i := 0; i < v.Len(); i++ {
   106  				vv := reflect.ValueOf(v.Index(i).Interface())
   107  				nv.Index(i).Set(vv.Convert(nv.Type().Elem()))
   108  			}
   109  		case reflect.Slice:
   110  			nv = reflect.MakeSlice(t, 0, 0)
   111  			for i := 0; i < v.Len(); i++ {
   112  				vv := reflect.ValueOf(v.Index(i).Interface())
   113  				nv = reflect.Append(nv, vv.Convert(nv.Type().Elem()))
   114  			}
   115  		default:
   116  			panic("unable to convert slice to non-array, non-slice type")
   117  		}
   118  	}
   119  	if v.Type() != t {
   120  		nv = nv.Convert(t)
   121  	}
   122  	return nv
   123  }