tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/fn/call.go (about) 1 package fn 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/mitchellh/mapstructure" 8 ) 9 10 var errorInterface = reflect.TypeOf((*error)(nil)).Elem() 11 12 // Call wraps invoking a function via reflection, converting the arguments with 13 // ArgsTo and the returns with ParseReturn. fn argument can be a function 14 // or a reflect.Value for a function. 15 func Call(fn any, args []any) (_ []any, err error) { 16 fnval := reflect.ValueOf(fn) 17 if rv, ok := fn.(reflect.Value); ok { 18 fnval = rv 19 } 20 fnParams, err := ArgsTo(fnval.Type(), args) 21 if err != nil { 22 return nil, err 23 } 24 fnReturn := fnval.Call(fnParams) 25 return ParseReturn(fnReturn) 26 } 27 28 // ArgsTo converts the arguments into `reflect.Value`s suitable to pass as 29 // parameters to a function with the given type via reflection. 30 func ArgsTo(fntyp reflect.Type, args []any) ([]reflect.Value, error) { 31 if len(args) != fntyp.NumIn() { 32 return nil, fmt.Errorf("fn: expected %d params, got %d", fntyp.NumIn(), len(args)) 33 } 34 fnParams := make([]reflect.Value, len(args)) 35 for idx, param := range args { 36 switch fntyp.In(idx).Kind() { 37 case reflect.Struct: 38 // decode to struct type using mapstructure 39 arg := reflect.New(fntyp.In(idx)) 40 if err := mapstructure.Decode(param, arg.Interface()); err != nil { 41 return nil, fmt.Errorf("fn: mapstructure: %s", err.Error()) 42 } 43 fnParams[idx] = ensureType(arg.Elem(), fntyp.In(idx)) 44 case reflect.Slice: 45 rv := reflect.ValueOf(param) 46 // decode slice of structs to struct type using mapstructure 47 if fntyp.In(idx).Elem().Kind() == reflect.Struct { 48 nv := reflect.MakeSlice(fntyp.In(idx), rv.Len(), rv.Len()) 49 for i := 0; i < rv.Len(); i++ { 50 ref := reflect.New(nv.Index(i).Type()) 51 if err := mapstructure.Decode(rv.Index(i).Interface(), ref.Interface()); err != nil { 52 return nil, fmt.Errorf("fn: mapstructure: %s", err.Error()) 53 } 54 nv.Index(i).Set(reflect.Indirect(ref)) 55 } 56 rv = nv 57 } 58 fnParams[idx] = rv 59 default: 60 // if int is expected but got float64 assume json-like encoding and cast float to int 61 if fntyp.In(idx).Kind() == reflect.Int && reflect.TypeOf(param).Kind() == reflect.Float64 { 62 param = int(param.(float64)) 63 } 64 fnParams[idx] = ensureType(reflect.ValueOf(param), fntyp.In(idx)) 65 } 66 } 67 return fnParams, nil 68 } 69 70 // ParseReturn splits the results of reflect.Call() into the values, and 71 // possibly an error. 72 // If the last value is a non-nil error, this will return `nil, err`. 73 // If the last value is a nil error it will be removed from the value list. 74 // Any remaining values will be converted and returned as `any` typed values. 75 func ParseReturn(ret []reflect.Value) ([]any, error) { 76 if len(ret) == 0 { 77 return nil, nil 78 } 79 last := ret[len(ret)-1] 80 if last.Type().Implements(errorInterface) { 81 if !last.IsNil() { 82 return nil, last.Interface().(error) 83 } 84 ret = ret[:len(ret)-1] 85 } 86 out := make([]any, len(ret)) 87 for i, r := range ret { 88 out[i] = r.Interface() 89 } 90 return out, nil 91 } 92 93 // ensureType ensures a value is converted to the expected 94 // defined type from a convertable underlying type 95 func ensureType(v reflect.Value, t reflect.Type) reflect.Value { 96 if !v.IsValid() { 97 // handle nil values with zero value of expected type 98 return reflect.New(t).Elem() 99 } 100 nv := v 101 if v.Type().Kind() == reflect.Slice && v.Type().Elem() != t { 102 switch t.Kind() { 103 case reflect.Array: 104 nv = reflect.Indirect(reflect.New(t)) 105 for i := 0; i < v.Len(); i++ { 106 vv := reflect.ValueOf(v.Index(i).Interface()) 107 nv.Index(i).Set(vv.Convert(nv.Type().Elem())) 108 } 109 case reflect.Slice: 110 nv = reflect.MakeSlice(t, 0, 0) 111 for i := 0; i < v.Len(); i++ { 112 vv := reflect.ValueOf(v.Index(i).Interface()) 113 nv = reflect.Append(nv, vv.Convert(nv.Type().Elem())) 114 } 115 default: 116 panic("unable to convert slice to non-array, non-slice type") 117 } 118 } 119 if v.Type() != t { 120 nv = nv.Convert(t) 121 } 122 return nv 123 }