go-hep.org/x/hep@v0.38.1/groot/rtree/formula_test.go (about)

     1  // Copyright ©2020 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rtree
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"reflect"
    11  	"testing"
    12  
    13  	"go-hep.org/x/hep/groot/riofs"
    14  	"go-hep.org/x/hep/groot/root"
    15  )
    16  
    17  func TestFormulaFunc(t *testing.T) {
    18  	for _, tc := range []struct {
    19  		fname    string
    20  		tname    string
    21  		rvars    int
    22  		fct      any
    23  		branches []string
    24  		want     []any
    25  		err      error
    26  	}{
    27  		{
    28  			fname:    "../testdata/simple.root",
    29  			tname:    "tree",
    30  			rvars:    -1,
    31  			fct:      func(x int32) int32 { return x },
    32  			branches: []string{"one"},
    33  			want:     []any{int32(1), int32(2)},
    34  		},
    35  		{
    36  			fname: "../testdata/simple.root",
    37  			tname: "tree",
    38  			rvars: -1,
    39  			fct: func(x1 int32, x2 float32) float64 {
    40  				return float64(x1) + float64(x2*100)
    41  			},
    42  			branches: []string{"one", "two"},
    43  			want:     []any{float64(111), float64(222)},
    44  		},
    45  		{
    46  			fname: "../testdata/simple.root",
    47  			tname: "tree",
    48  			rvars: 0,
    49  			fct: func(x1 int32, x2 float32) float64 {
    50  				return float64(x1) + float64(x2*100)
    51  			},
    52  			branches: []string{"one", "two"},
    53  			want:     []any{float64(111), float64(222)},
    54  		},
    55  		{
    56  			fname: "../testdata/simple.root",
    57  			tname: "tree",
    58  			rvars: 1,
    59  			fct: func(x1 int32, x2 float32) float64 {
    60  				return float64(x1) + float64(x2*100)
    61  			},
    62  			branches: []string{"one", "two"},
    63  			want:     []any{float64(111), float64(222)},
    64  		},
    65  		{
    66  			fname: "../testdata/simple.root",
    67  			tname: "tree",
    68  			rvars: -1,
    69  			fct: func(x1 int32) int32 {
    70  				return x1 * x1
    71  			},
    72  			branches: []string{"one"},
    73  			want:     []any{int32(1), int32(4)},
    74  		},
    75  		{
    76  			fname: "../testdata/simple.root",
    77  			tname: "tree",
    78  			rvars: -1,
    79  			fct: func(x1 int32) float64 {
    80  				return math.Sqrt(float64(x1 * x1))
    81  			},
    82  			branches: []string{"one"},
    83  			want:     []any{float64(1), float64(2)},
    84  		},
    85  		{
    86  			fname: "../testdata/simple.root",
    87  			tname: "tree",
    88  			rvars: -1,
    89  			fct: func(x1 int32) string {
    90  				return fmt.Sprintf("%d", x1)
    91  			},
    92  			branches: []string{"one"},
    93  			want:     []any{"1", "2"},
    94  		},
    95  		{
    96  			fname: "../testdata/leaves.root",
    97  			tname: "tree",
    98  			rvars: -1,
    99  			fct: func(x [10]uint64) [10]uint64 {
   100  				return x
   101  			},
   102  			branches: []string{"ArrU64"},
   103  			want:     []any{[10]uint64{}, [10]uint64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}},
   104  		},
   105  		{
   106  			fname: "../testdata/leaves.root",
   107  			tname: "tree",
   108  			rvars: -1,
   109  			fct: func(x [10]uint64) uint64 {
   110  				return x[0]
   111  			},
   112  			branches: []string{"ArrU64"},
   113  			want:     []any{uint64(0), uint64(1)},
   114  		},
   115  		{
   116  			fname: "../testdata/leaves.root",
   117  			tname: "tree",
   118  			rvars: -1,
   119  			fct: func(x []float32) []float64 {
   120  				o := make([]float64, len(x))
   121  				for i, v := range x {
   122  					o[i] = float64(2 * v)
   123  				}
   124  				return o
   125  			},
   126  			branches: []string{"SliF32"},
   127  			want:     []any{[]float64{}, []float64{2}},
   128  		},
   129  		{
   130  			fname: "../testdata/small-evnt-tree-fullsplit.root",
   131  			tname: "tree",
   132  			rvars: -1,
   133  			fct: func(x []float32) []float64 {
   134  				o := make([]float64, len(x))
   135  				for i, v := range x {
   136  					o[i] = float64(2 * v)
   137  				}
   138  				return o
   139  			},
   140  			branches: []string{"evt.SliceF32"},
   141  			want:     []any{[]float64{}, []float64{2}},
   142  		},
   143  		{
   144  			fname: "../testdata/small-evnt-tree-fullsplit.root",
   145  			tname: "tree",
   146  			rvars: -1,
   147  			fct: func(x []float32) []float64 {
   148  				o := make([]float64, len(x))
   149  				for i, v := range x {
   150  					o[i] = float64(2 * v)
   151  				}
   152  				return o
   153  			},
   154  			branches: []string{"evt.StlVecF32"},
   155  			want:     []any{[]float64{}, []float64{2}},
   156  		},
   157  		{
   158  			fname: "../testdata/embedded-std-vector.root",
   159  			tname: "modules",
   160  			rvars: -1,
   161  			fct: func(x []float32) []float64 {
   162  				o := make([]float64, len(x))
   163  				for i, v := range x {
   164  					o[i] = float64(2 * v)
   165  				}
   166  				return o
   167  			},
   168  			branches: []string{"hits_time_mc"},
   169  			want: []any{
   170  				[]float64{
   171  					24.412797927856445, 23.422243118286133,
   172  					23.469839096069336, 24.914079666137695,
   173  					23.116113662719727, 23.13003921508789,
   174  					23.375518798828125, 23.057828903198242,
   175  					25.786481857299805, 22.85857582092285,
   176  				},
   177  				[]float64{
   178  					23.436037063598633, 25.970693588256836,
   179  					24.462419509887695, 23.650163650512695,
   180  					24.811952590942383, 30.67894172668457,
   181  					23.878101348876953, 25.87006378173828,
   182  					27.323381423950195, 23.939083099365234,
   183  					23.786226272583008,
   184  				},
   185  			},
   186  		},
   187  		{
   188  			fname: "../testdata/leaves.root",
   189  			tname: "tree",
   190  			rvars: -1,
   191  			fct: func(x root.Float16) root.Float16 {
   192  				return x
   193  			},
   194  			branches: []string{"D16"},
   195  			want:     []any{root.Float16(0.0), root.Float16(1.0)},
   196  		},
   197  		{
   198  			fname: "../testdata/leaves.root",
   199  			tname: "tree",
   200  			rvars: -1,
   201  			fct: func(x root.Double32) root.Double32 {
   202  				return x
   203  			},
   204  			branches: []string{"D32"},
   205  			want:     []any{root.Double32(0.0), root.Double32(1.0)},
   206  		},
   207  		{
   208  			fname: "../testdata/leaves.root",
   209  			tname: "tree",
   210  			rvars: -1,
   211  			fct: func(x [10]root.Double32) root.Double32 {
   212  				return x[0]
   213  			},
   214  			branches: []string{"ArrD32"},
   215  			want:     []any{root.Double32(0), root.Double32(1)},
   216  		},
   217  		{
   218  			fname: "../testdata/leaves.root",
   219  			tname: "tree",
   220  			rvars: -1,
   221  			fct: func(x1 root.Double32, x2 []int64) float64 {
   222  				return float64(x1) + float64(len(x2))
   223  			},
   224  			branches: []string{"D32", "SliI64"},
   225  			want:     []any{0.0, 2.0},
   226  		},
   227  		{
   228  			fname: "../testdata/leaves.root",
   229  			tname: "tree",
   230  			rvars: -1,
   231  			fct: func() float64 {
   232  				return 42.0
   233  			},
   234  			branches: nil,
   235  			want:     []any{42.0, 42.0},
   236  		},
   237  		{
   238  			fname:    "../testdata/leaves.root",
   239  			tname:    "tree",
   240  			rvars:    -1,
   241  			fct:      func(v bool) bool { return v },
   242  			branches: []string{"B"},
   243  			want:     []any{true, false},
   244  		},
   245  		{
   246  			fname:    "../testdata/leaves.root",
   247  			tname:    "tree",
   248  			rvars:    -1,
   249  			fct:      func(v int8) int8 { return v },
   250  			branches: []string{"I8"},
   251  			want:     []any{int8(0), int8(-1)},
   252  		},
   253  		{
   254  			fname:    "../testdata/leaves.root",
   255  			tname:    "tree",
   256  			rvars:    -1,
   257  			fct:      func(v int16) int16 { return v },
   258  			branches: []string{"I16"},
   259  			want:     []any{int16(0), int16(-1)},
   260  		},
   261  		{
   262  			fname:    "../testdata/leaves.root",
   263  			tname:    "tree",
   264  			rvars:    -1,
   265  			fct:      func(v int32) int32 { return v },
   266  			branches: []string{"I32"},
   267  			want:     []any{int32(0), int32(-1)},
   268  		},
   269  		{
   270  			fname:    "../testdata/leaves.root",
   271  			tname:    "tree",
   272  			rvars:    -1,
   273  			fct:      func(v int64) int64 { return v },
   274  			branches: []string{"I64"},
   275  			want:     []any{int64(0), int64(-1)},
   276  		},
   277  		{
   278  			fname:    "../testdata/leaves.root",
   279  			tname:    "tree",
   280  			rvars:    -1,
   281  			fct:      func(v uint8) uint8 { return v },
   282  			branches: []string{"U8"},
   283  			want:     []any{uint8(0), uint8(1)},
   284  		},
   285  		{
   286  			fname:    "../testdata/leaves.root",
   287  			tname:    "tree",
   288  			rvars:    -1,
   289  			fct:      func(v uint16) uint16 { return v },
   290  			branches: []string{"U16"},
   291  			want:     []any{uint16(0), uint16(1)},
   292  		},
   293  		{
   294  			fname:    "../testdata/leaves.root",
   295  			tname:    "tree",
   296  			rvars:    -1,
   297  			fct:      func(v uint32) uint32 { return v },
   298  			branches: []string{"U32"},
   299  			want:     []any{uint32(0), uint32(1)},
   300  		},
   301  		{
   302  			fname:    "../testdata/leaves.root",
   303  			tname:    "tree",
   304  			rvars:    -1,
   305  			fct:      func(v uint64) uint64 { return v },
   306  			branches: []string{"U64"},
   307  			want:     []any{uint64(0), uint64(1)},
   308  		},
   309  		{
   310  			fname:    "../testdata/leaves.root",
   311  			tname:    "tree",
   312  			rvars:    -1,
   313  			fct:      func(v float32) float32 { return v },
   314  			branches: []string{"F32"},
   315  			want:     []any{float32(0), float32(1)},
   316  		},
   317  		{
   318  			fname:    "../testdata/leaves.root",
   319  			tname:    "tree",
   320  			rvars:    -1,
   321  			fct:      func(v float64) float64 { return v },
   322  			branches: []string{"F64"},
   323  			want:     []any{float64(0), float64(1)},
   324  		},
   325  		{
   326  			fname:    "../testdata/leaves.root",
   327  			tname:    "tree",
   328  			rvars:    -1,
   329  			fct:      func(v string) string { return v },
   330  			branches: []string{"Str"},
   331  			want:     []any{"str-0", "str-1"},
   332  		},
   333  		{
   334  			fname:    "../testdata/leaves.root",
   335  			tname:    "tree",
   336  			rvars:    -1,
   337  			fct:      func(v string) [1]string { return [1]string{v} },
   338  			branches: []string{"Str"},
   339  			want:     []any{[1]string{"str-0"}, [1]string{"str-1"}},
   340  		},
   341  		{
   342  			fname:    "../testdata/simple.root",
   343  			tname:    "tree",
   344  			rvars:    -1,
   345  			fct:      func(x int32) int32 { return x },
   346  			branches: []string{"ones"},
   347  			err:      fmt.Errorf(`rtree: could not create formula: rtree: could not find all needed ReadVars (missing: [ones])`),
   348  		},
   349  		{
   350  			fname:    "../testdata/simple.root",
   351  			tname:    "tree",
   352  			rvars:    -1,
   353  			fct:      func(x int32, y float32, z int32) int32 { return x },
   354  			branches: []string{"one", "twos", "ones"},
   355  			err:      fmt.Errorf(`rtree: could not create formula: rtree: could not find all needed ReadVars (missing: [twos ones])`),
   356  		},
   357  		{
   358  			fname:    "../testdata/simple.root",
   359  			tname:    "tree",
   360  			rvars:    -1,
   361  			fct:      func(x1 int32, x2 float64) float64 { return 0 },
   362  			branches: []string{"one", "two"},
   363  			err:      fmt.Errorf(`rtree: could not create formula: rtree: could not bind formula to rvars: rfunc: argument type 1 (name=two) mismatch: got=float32, want=float64`),
   364  		},
   365  		{
   366  			fname:    "../testdata/simple.root",
   367  			tname:    "tree",
   368  			rvars:    -1,
   369  			fct:      "not a func",
   370  			branches: []string{"one", "two"},
   371  			err:      fmt.Errorf(`rtree: could not create formula: rfunc: formula expects a func`),
   372  		},
   373  		{
   374  			fname:    "../testdata/simple.root",
   375  			tname:    "tree",
   376  			rvars:    -1,
   377  			fct:      func(x1 int32, x2 float64) float64 { return 0 },
   378  			branches: []string{"one"},
   379  			err:      fmt.Errorf(`rtree: could not create formula: rfunc: num-branches/func-arity mismatch`),
   380  		},
   381  		{
   382  			fname:    "../testdata/simple.root",
   383  			tname:    "tree",
   384  			rvars:    -1,
   385  			fct:      func(x1 int32) float64 { return 0 },
   386  			branches: []string{"one", "two"},
   387  			err:      fmt.Errorf(`rtree: could not create formula: rfunc: num-branches/func-arity mismatch`),
   388  		},
   389  		{
   390  			fname:    "../testdata/simple.root",
   391  			tname:    "tree",
   392  			rvars:    -1,
   393  			fct:      func(x1 int32) (a, b float64) { return },
   394  			branches: []string{"one"},
   395  			err:      fmt.Errorf(`rtree: could not create formula: rfunc: invalid number of return values`),
   396  		},
   397  	} {
   398  		t.Run("", func(t *testing.T) {
   399  			f, err := riofs.Open(tc.fname)
   400  			if err != nil {
   401  				t.Fatal(err)
   402  			}
   403  			defer f.Close()
   404  
   405  			o, err := riofs.Dir(f).Get(tc.tname)
   406  			if err != nil {
   407  				t.Fatal(err)
   408  			}
   409  
   410  			tree := o.(Tree)
   411  
   412  			var rvars []ReadVar
   413  			switch tc.rvars {
   414  			case -1:
   415  				rvars = NewReadVars(tree)
   416  			case 0:
   417  				rvars = nil
   418  			default:
   419  				rvars = NewReadVars(tree)[:tc.rvars]
   420  			}
   421  
   422  			r, err := NewReader(tree, rvars, WithRange(0, 2))
   423  			if err != nil {
   424  				t.Fatal(err)
   425  			}
   426  			defer r.Close()
   427  
   428  			form, err := r.FormulaFunc(tc.branches, tc.fct)
   429  			switch {
   430  			case err != nil && tc.err != nil:
   431  				if got, want := err.Error(), tc.err.Error(); got != want {
   432  					t.Fatalf("invalid error.\ngot= %v\nwant=%v", got, want)
   433  				}
   434  				return
   435  			case err != nil && tc.err == nil:
   436  				t.Fatalf("unexpected error: %+v", err)
   437  			case err == nil && tc.err != nil:
   438  				t.Fatalf("expected an error: %v (got=nil)", tc.err)
   439  			case err == nil && tc.err == nil:
   440  				// ok.
   441  			}
   442  
   443  			defer func() {
   444  				e := recover()
   445  				if e != nil {
   446  					t.Fatalf("could not run form-eval: %+v", e)
   447  				}
   448  			}()
   449  
   450  			err = r.Read(func(ctx RCtx) error {
   451  				if got, want := reflect.ValueOf(form.Func()).Call(nil)[0].Interface(), tc.want[ctx.Entry]; !reflect.DeepEqual(got, want) {
   452  					return fmt.Errorf("entry[%d]: invalid form-eval:\ngot= %v (%T)\nwant=%v (%T)", ctx.Entry, got, got, want, want)
   453  				}
   454  
   455  				return nil
   456  			})
   457  			if err != nil {
   458  				t.Fatalf("error: %+v", err)
   459  			}
   460  		})
   461  	}
   462  }
   463  
   464  var sumBenchFormulaFunc float64
   465  
   466  func BenchmarkFormulaFunc(b *testing.B) {
   467  	for _, tc := range []struct {
   468  		name string
   469  		fct  any
   470  		brs  []string
   471  	}{
   472  		{
   473  			name: "f0",
   474  			fct:  func() float64 { return 42 },
   475  		},
   476  		{
   477  			name: "f1",
   478  			fct:  func(x float64) float64 { return x },
   479  			brs:  []string{"F64"},
   480  		},
   481  		{
   482  			name: "f2",
   483  			fct:  func(x float64) float64 { return 2 * x },
   484  			brs:  []string{"F64"},
   485  		},
   486  		{
   487  			name: "f3",
   488  			fct:  func(x float64) float64 { return math.Abs(2 * x) },
   489  			brs:  []string{"F64"},
   490  		},
   491  	} {
   492  		b.Run(tc.name, func(b *testing.B) {
   493  			f, err := riofs.Open("../testdata/leaves.root")
   494  			if err != nil {
   495  				b.Fatal(err)
   496  			}
   497  			defer f.Close()
   498  
   499  			o, err := f.Get("tree")
   500  			if err != nil {
   501  				b.Fatal(err)
   502  			}
   503  			tree := o.(Tree)
   504  
   505  			r, err := NewReader(tree, nil)
   506  			if err != nil {
   507  				b.Fatal(err)
   508  			}
   509  
   510  			form, err := r.FormulaFunc(tc.brs, tc.fct)
   511  			if err != nil {
   512  				b.Fatal(err)
   513  			}
   514  
   515  			b.Run("Func", func(b *testing.B) {
   516  				eval := form.Func().(func() float64)
   517  
   518  				err = r.Read(func(ctx RCtx) error {
   519  					sumBenchFormulaFunc += eval()
   520  					return nil
   521  				})
   522  				if err != nil {
   523  					b.Fatalf("error: %+v", err)
   524  				}
   525  
   526  				r.r.reset()
   527  
   528  				sumBenchFormulaFunc = 0
   529  				b.ReportAllocs()
   530  				b.ResetTimer()
   531  
   532  				for i := 0; i < b.N; i++ {
   533  					sumBenchFormulaFunc += eval()
   534  				}
   535  			})
   536  		})
   537  	}
   538  }