github.com/alecthomas/kong@v0.9.1-0.20240410131203-2ab5733f1179/callbacks.go (about) 1 package kong 2 3 import ( 4 "fmt" 5 "reflect" 6 "strings" 7 ) 8 9 type bindings map[reflect.Type]func() (reflect.Value, error) 10 11 func (b bindings) String() string { 12 out := []string{} 13 for k := range b { 14 out = append(out, k.String()) 15 } 16 return "bindings{" + strings.Join(out, ", ") + "}" 17 } 18 19 func (b bindings) add(values ...interface{}) bindings { 20 for _, v := range values { 21 v := v 22 b[reflect.TypeOf(v)] = func() (reflect.Value, error) { return reflect.ValueOf(v), nil } 23 } 24 return b 25 } 26 27 func (b bindings) addTo(impl, iface interface{}) { 28 valueOf := reflect.ValueOf(impl) 29 b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil } 30 } 31 32 func (b bindings) addProvider(provider interface{}) error { 33 pv := reflect.ValueOf(provider) 34 t := pv.Type() 35 if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { 36 return fmt.Errorf("%T must be a function with the signature func()(T, error)", provider) 37 } 38 rt := pv.Type().Out(0) 39 b[rt] = func() (reflect.Value, error) { 40 out := pv.Call(nil) 41 errv := out[1] 42 var err error 43 if !errv.IsNil() { 44 err = errv.Interface().(error) //nolint 45 } 46 return out[0], err 47 } 48 return nil 49 } 50 51 // Clone and add values. 52 func (b bindings) clone() bindings { 53 out := make(bindings, len(b)) 54 for k, v := range b { 55 out[k] = v 56 } 57 return out 58 } 59 60 func (b bindings) merge(other bindings) bindings { 61 for k, v := range other { 62 b[k] = v 63 } 64 return b 65 } 66 67 func getMethod(value reflect.Value, name string) reflect.Value { 68 method := value.MethodByName(name) 69 if !method.IsValid() { 70 if value.CanAddr() { 71 method = value.Addr().MethodByName(name) 72 } 73 } 74 return method 75 } 76 77 func callFunction(f reflect.Value, bindings bindings) error { 78 if f.Kind() != reflect.Func { 79 return fmt.Errorf("expected function, got %s", f.Type()) 80 } 81 in := []reflect.Value{} 82 t := f.Type() 83 if t.NumOut() != 1 || !t.Out(0).Implements(callbackReturnSignature) { 84 return fmt.Errorf("return value of %s must implement \"error\"", t) 85 } 86 for i := 0; i < t.NumIn(); i++ { 87 pt := t.In(i) 88 if argf, ok := bindings[pt]; ok { 89 argv, err := argf() 90 if err != nil { 91 return err 92 } 93 in = append(in, argv) 94 } else { 95 return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) 96 } 97 } 98 out := f.Call(in) 99 if out[0].IsNil() { 100 return nil 101 } 102 return out[0].Interface().(error) //nolint 103 } 104 105 func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) { 106 if f.Kind() != reflect.Func { 107 return nil, fmt.Errorf("expected function, got %s", f.Type()) 108 } 109 in := []reflect.Value{} 110 t := f.Type() 111 for i := 0; i < t.NumIn(); i++ { 112 pt := t.In(i) 113 if argf, ok := bindings[pt]; ok { 114 argv, err := argf() 115 if err != nil { 116 return nil, err 117 } 118 in = append(in, argv) 119 } else { 120 return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) 121 } 122 } 123 outv := f.Call(in) 124 out = make([]any, len(outv)) 125 for i, v := range outv { 126 out[i] = v.Interface() 127 } 128 return out, nil 129 }