go-hep.org/x/hep@v0.38.1/groot/rtree/rfunc/rfunc.go (about)

     1  // Copyright ©2020 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package rfunc provides types and funcs to implement user-provided formulae
     6  // evaluated on data exposed by ROOT trees.
     7  package rfunc // import "go-hep.org/x/hep/groot/rtree/rfunc"
     8  
     9  //go:generate go run ./gen-rfuncs.go
    10  
    11  import (
    12  	"fmt"
    13  	"reflect"
    14  )
    15  
// Formula is the interface that describes the protocol between a
// user-provided function (that evaluates a value based on some data in a
// ROOT tree) and the rtree.Reader (that presents data from a ROOT tree).
type Formula interface {
	// RVars returns the names of the leaves that this formula needs.
	// The returned slice must contain the names in the same order as the
	// arguments of the user formula function.
	RVars() []string

	// Bind provides the arguments to the user function.
	// ptrs is a slice of pointers to the rtree.ReadVars, in the same order
	// as the names returned by RVars.
	Bind(ptrs []any) error

	// Func returns the user function closing on the bound
	// pointers-to-arguments and returning the expected evaluated value.
	Func() any
}
    34  
    35  // NewGenericFormula returns a new formula from the provided list of needed
    36  // tree variables and the provided user function.
    37  // NewGenericFormula uses reflect to bind read-vars and the generic function.
    38  func NewGenericFormula(rvars []string, fct any) (Formula, error) {
    39  	return newGenericFormula(rvars, fct)
    40  }
    41  
    42  type genericFormula struct {
    43  	names []string
    44  	fct   any
    45  
    46  	ptrs []reflect.Value
    47  	args []reflect.Value
    48  	out  []reflect.Value
    49  
    50  	rfct reflect.Value // formula-created function to eval read-vars
    51  	ufct reflect.Value // user-provided function
    52  }
    53  
    54  func newGenericFormula(names []string, fct any) (*genericFormula, error) {
    55  	rv := reflect.ValueOf(fct)
    56  	if rv.Kind() != reflect.Func {
    57  		return nil, fmt.Errorf("rfunc: formula expects a func")
    58  	}
    59  
    60  	if len(names) != rv.Type().NumIn() {
    61  		return nil, fmt.Errorf("rfunc: num-branches/func-arity mismatch")
    62  	}
    63  
    64  	if rv.Type().NumOut() != 1 {
    65  		// FIXME(sbinet): allow any kind of function?
    66  		return nil, fmt.Errorf("rfunc: invalid number of return values")
    67  	}
    68  
    69  	args := make([]reflect.Value, len(names))
    70  	ptrs := make([]reflect.Value, len(names))
    71  	for i := range args {
    72  		args[i] = reflect.New(rv.Type().In(i)).Elem()
    73  	}
    74  
    75  	gen := &genericFormula{
    76  		names: names,
    77  		fct:   fct,
    78  
    79  		ptrs: ptrs,
    80  		args: args,
    81  		ufct: rv,
    82  	}
    83  
    84  	gen.rfct = reflect.MakeFunc(
    85  		reflect.FuncOf(nil, []reflect.Type{rv.Type().Out(0)}, false),
    86  		func(in []reflect.Value) []reflect.Value {
    87  			gen.eval()
    88  			return gen.out
    89  		},
    90  	)
    91  
    92  	return gen, nil
    93  }
    94  
    95  func (f *genericFormula) RVars() []string { return f.names }
    96  func (f *genericFormula) Bind(args []any) error {
    97  	if got, want := len(args), len(f.ptrs); got != want {
    98  		return fmt.Errorf(
    99  			"rfunc: invalid number of bind arguments (got=%d, want=%d)",
   100  			got, want,
   101  		)
   102  	}
   103  
   104  	for i := range args {
   105  		var (
   106  			got  = reflect.TypeOf(args[i]).Elem()
   107  			want = f.args[i].Type()
   108  		)
   109  		if got != want {
   110  			return fmt.Errorf(
   111  				"rfunc: argument type %d (name=%s) mismatch: got=%T, want=%T",
   112  				i, f.names[i],
   113  				reflect.New(got).Elem().Interface(),
   114  				reflect.New(want).Elem().Interface(),
   115  			)
   116  		}
   117  		f.ptrs[i] = reflect.ValueOf(args[i])
   118  		f.args[i] = reflect.New(f.ptrs[i].Type().Elem()).Elem()
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  func (f *genericFormula) eval() {
   125  	for i := range f.ptrs {
   126  		f.args[i].Set(f.ptrs[i].Elem())
   127  	}
   128  	f.out = f.ufct.Call(f.args)
   129  }
   130  
   131  func (f *genericFormula) Func() any {
   132  	return f.rfct.Interface()
   133  }
   134  
// Compile-time check that *genericFormula implements the Formula interface.
var (
	_ Formula = (*genericFormula)(nil)
)