github.com/maresnic/mr-kong@v1.0.0/callbacks.go (about)

     1  package kong
     2  
     import (
	"fmt"
	"reflect"
	"sort"
	"strings"
)
     8  
     9  type bindings map[reflect.Type]func() (reflect.Value, error)
    10  
    11  func (b bindings) String() string {
    12  	out := []string{}
    13  	for k := range b {
    14  		out = append(out, k.String())
    15  	}
    16  	return "bindings{" + strings.Join(out, ", ") + "}"
    17  }
    18  
    19  func (b bindings) add(values ...interface{}) bindings {
    20  	for _, v := range values {
    21  		v := v
    22  		b[reflect.TypeOf(v)] = func() (reflect.Value, error) { return reflect.ValueOf(v), nil }
    23  	}
    24  	return b
    25  }
    26  
    27  func (b bindings) addTo(impl, iface interface{}) {
    28  	valueOf := reflect.ValueOf(impl)
    29  	b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil }
    30  }
    31  
    32  func (b bindings) addProvider(provider interface{}) error {
    33  	pv := reflect.ValueOf(provider)
    34  	t := pv.Type()
    35  	if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
    36  		return fmt.Errorf("%T must be a function with the signature func()(T, error)", provider)
    37  	}
    38  	rt := pv.Type().Out(0)
    39  	b[rt] = func() (reflect.Value, error) {
    40  		out := pv.Call(nil)
    41  		errv := out[1]
    42  		var err error
    43  		if !errv.IsNil() {
    44  			err = errv.Interface().(error) //nolint
    45  		}
    46  		return out[0], err
    47  	}
    48  	return nil
    49  }
    50  
    51  // Clone and add values.
    52  func (b bindings) clone() bindings {
    53  	out := make(bindings, len(b))
    54  	for k, v := range b {
    55  		out[k] = v
    56  	}
    57  	return out
    58  }
    59  
    60  func (b bindings) merge(other bindings) bindings {
    61  	for k, v := range other {
    62  		b[k] = v
    63  	}
    64  	return b
    65  }
    66  
    // getMethod looks up the method called name on value. If value's own method
// set has no such method but value is addressable, the lookup is retried on
// value.Addr(), so pointer-receiver methods are found as well. The returned
// Value is invalid (IsValid reports false) when neither lookup succeeds.
func getMethod(value reflect.Value, name string) reflect.Value {
	if m := value.MethodByName(name); m.IsValid() || !value.CanAddr() {
		return m
	}
	return value.Addr().MethodByName(name)
}
    76  
    77  func callFunction(f reflect.Value, bindings bindings) error {
    78  	if f.Kind() != reflect.Func {
    79  		return fmt.Errorf("expected function, got %s", f.Type())
    80  	}
    81  	in := []reflect.Value{}
    82  	t := f.Type()
    83  	if t.NumOut() != 1 || !t.Out(0).Implements(callbackReturnSignature) {
    84  		return fmt.Errorf("return value of %s must implement \"error\"", t)
    85  	}
    86  	for i := 0; i < t.NumIn(); i++ {
    87  		pt := t.In(i)
    88  		if argf, ok := bindings[pt]; ok {
    89  			argv, err := argf()
    90  			if err != nil {
    91  				return err
    92  			}
    93  			in = append(in, argv)
    94  		} else {
    95  			return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
    96  		}
    97  	}
    98  	out := f.Call(in)
    99  	if out[0].IsNil() {
   100  		return nil
   101  	}
   102  	return out[0].Interface().(error) //nolint
   103  }
   104  
   105  func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) {
   106  	if f.Kind() != reflect.Func {
   107  		return nil, fmt.Errorf("expected function, got %s", f.Type())
   108  	}
   109  	in := []reflect.Value{}
   110  	t := f.Type()
   111  	for i := 0; i < t.NumIn(); i++ {
   112  		pt := t.In(i)
   113  		if argf, ok := bindings[pt]; ok {
   114  			argv, err := argf()
   115  			if err != nil {
   116  				return nil, err
   117  			}
   118  			in = append(in, argv)
   119  		} else {
   120  			return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
   121  		}
   122  	}
   123  	outv := f.Call(in)
   124  	out = make([]any, len(outv))
   125  	for i, v := range outv {
   126  		out[i] = v.Interface()
   127  	}
   128  	return out, nil
   129  }