go-hep.org/x/hep@v0.38.1/groot/rtree/rvar.go (about)

     1  // Copyright ©2020 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rtree
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"go-hep.org/x/hep/groot/root"
    15  	"golang.org/x/text/cases"
    16  	"golang.org/x/text/language"
    17  )
    18  
    19  func toTitle(s string) string {
    20  	return cases.Title(language.Und, cases.NoLower).String(s)
    21  }
    22  
    23  // ReadVar describes a variable to be read out of a tree.
    24  type ReadVar struct {
    25  	Name  string // name of the branch to read
    26  	Leaf  string // name of the leaf to read
    27  	Value any    // pointer to the value to fill
    28  
    29  	count string // name of the leaf-count, if any
    30  	leaf  Leaf   // leaf to which this read-var is bound
    31  }
    32  
    33  // NewReadVars returns the complete set of ReadVars to read all the data
    34  // contained in the provided Tree.
    35  func NewReadVars(t Tree) []ReadVar {
    36  	var vars []ReadVar
    37  	for _, b := range t.Branches() {
    38  		for _, leaf := range b.Leaves() {
    39  			ptr := newValue(leaf)
    40  			cnt := ""
    41  			if leaf.LeafCount() != nil {
    42  				cnt = leaf.LeafCount().Name()
    43  			}
    44  			vars = append(vars, ReadVar{Name: b.Name(), Leaf: leaf.Name(), Value: ptr, count: cnt, leaf: leaf})
    45  		}
    46  	}
    47  
    48  	return vars
    49  }
    50  
    51  // Deref returns the value pointed at by this read-var.
    52  func (rv ReadVar) Deref() any {
    53  	return reflect.ValueOf(rv.Value).Elem().Interface()
    54  }
    55  
    56  // ReadVarsFromStruct returns a list of ReadVars bound to the exported fields
    57  // of the provided pointer to a struct value.
    58  //
    59  // ReadVarsFromStruct panicks if the provided value is not a pointer to
    60  // a struct value.
    61  func ReadVarsFromStruct(ptr any) []ReadVar {
    62  	rv := reflect.ValueOf(ptr)
    63  	if rv.Kind() != reflect.Ptr {
    64  		panic(fmt.Errorf("rtree: expect a pointer value, got %T", ptr))
    65  	}
    66  
    67  	rv = rv.Elem()
    68  	if rv.Kind() != reflect.Struct {
    69  		panic(fmt.Errorf("rtree: expect a pointer to struct value, got %T", ptr))
    70  	}
    71  
    72  	var (
    73  		rt    = rv.Type()
    74  		rvars = make([]ReadVar, 0, rt.NumField())
    75  	)
    76  
    77  	for i := range rt.NumField() {
    78  		var (
    79  			ft = rt.Field(i)
    80  			fv = rv.Field(i)
    81  		)
    82  		if ft.Name != toTitle(ft.Name) {
    83  			// not exported. ignore.
    84  			continue
    85  		}
    86  
    87  		// check that struct-tag and field-type match
    88  		checkFieldTagConsistency(ft)
    89  
    90  		rvar := ReadVar{
    91  			Name:  nameOf(ft),
    92  			Value: fv.Addr().Interface(),
    93  		}
    94  
    95  		if strings.Contains(rvar.Name, "[") {
    96  			switch ft.Type.Kind() {
    97  			case reflect.Slice:
    98  				sli, dims := splitNameDims(rvar.Name)
    99  				if len(dims) > 1 {
   100  					panic(fmt.Errorf("rtree: invalid number of slice-dimensions for field %q: %q", ft.Name, rvar.Name))
   101  				}
   102  				rvar.Name = sli
   103  				rvar.count = dims[0]
   104  
   105  			case reflect.Array:
   106  				arr, dims := splitNameDims(rvar.Name)
   107  				if len(dims) > 3 {
   108  					panic(fmt.Errorf("rtree: invalid number of array-dimension for field %q: %q", ft.Name, rvar.Name))
   109  				}
   110  				rvar.Name = arr
   111  			default:
   112  				panic(fmt.Errorf("rtree: invalid field type for %q, or invalid struct-tag %q: %T", ft.Name, rvar.Name, fv.Interface()))
   113  			}
   114  		}
   115  		switch ft.Type.Kind() {
   116  		case reflect.Int, reflect.Uint, reflect.UnsafePointer, reflect.Uintptr, reflect.Chan, reflect.Interface:
   117  			panic(fmt.Errorf("rtree: invalid field type for %q: %T", ft.Name, fv.Interface()))
   118  		case reflect.Map:
   119  			panic(fmt.Errorf("rtree: invalid field type for %q: %T (not yet supported)", ft.Name, fv.Interface()))
   120  		}
   121  
   122  		rvar.Leaf = rvar.Name
   123  		rvars = append(rvars, rvar)
   124  	}
   125  	return rvars
   126  }
   127  
   // nameOf returns the name under which field is read: its "groot" struct-tag
// when present, otherwise the Go field name.
//
// For array fields whose tag carries no dimensions, the dimensions are
// appended from the Go type. This lets users write
//
//	Array [1]int64 `groot:"array"`
//
// instead of the ROOT/C++ way
//
//	Array [1]int64 `groot:"array[1]"`
func nameOf(field reflect.StructField) string {
	tag, ok := field.Tag.Lookup("groot")
	switch {
	case !ok:
		return field.Name
	case field.Type.Kind() != reflect.Array, strings.Contains(tag, "["):
		// not an array, or the user already spelled out the dimensions.
		return tag
	}

	var sb strings.Builder
	sb.WriteString(tag)
	for typ := field.Type; typ.Kind() == reflect.Array; typ = typ.Elem() {
		sb.WriteByte('[')
		sb.WriteString(strconv.Itoa(typ.Len()))
		sb.WriteByte(']')
	}
	return sb.String()
}
   160  
   // dimsOf returns the length of each nested array level of rt, outermost
// first (e.g. [2][3]T yields [2 3]). It returns nil when rt is not an array.
func dimsOf(rt reflect.Type) []int {
	var dims []int
	for ; rt.Kind() == reflect.Array; rt = rt.Elem() {
		dims = append(dims, rt.Len())
	}
	return dims
}
   174  
   175  func checkFieldTagConsistency(ft reflect.StructField) {
   176  	tag, ok := ft.Tag.Lookup("groot")
   177  	if !ok {
   178  		// nothing to check
   179  		return
   180  	}
   181  
   182  	rt := ft.Type
   183  	switch rt.Kind() {
   184  	case reflect.Array:
   185  		if !strings.Contains(tag, "[") {
   186  			// nothing to check
   187  			return
   188  		}
   189  		fromTyp := 1
   190  		for _, d := range dimsOf(rt) {
   191  			fromTyp *= d
   192  		}
   193  		fromTag := 1
   194  		_, rdims := splitNameDims(tag)
   195  		for _, s := range rdims {
   196  			v, err := strconv.Atoi(s)
   197  			if err != nil {
   198  				panic(fmt.Errorf("rtree: could not infer dimensions from %q: %w", tag, err))
   199  			}
   200  			fromTag *= v
   201  		}
   202  
   203  		if fromTyp != fromTag {
   204  			panic(fmt.Errorf("rtree: field type dimension inconsistency: groot-tag=%q vs go-type=%v: %d vs %d",
   205  				tag, ft.Type, fromTag, fromTyp,
   206  			))
   207  		}
   208  	}
   209  }
   210  
   // reSplitNameDims matches each "[dim]" group of a ROOT branch title and
// captures the text between the brackets. It is compiled once, at package
// initialization, rather than on every splitNameDims call.
var reSplitNameDims = regexp.MustCompile(`\w*?\[(\w*)\]+?`)

// splitNameDims returns the name and dimensions of a ROOT
// branch title.
//
// The name is the part of s before the first '['. When s has no '[' or
// starts with one, the name is s itself. The dimensions are the bracketed
// parts, in order, and are nil when there are none. For example,
// "arr[2][N]" yields ("arr", ["2" "N"]) and "foo" yields ("foo", nil).
func splitNameDims(s string) (string, []string) {
	name := s
	if i := strings.Index(s, "["); i > 0 {
		name = s[:i]
	}

	matches := reSplitNameDims.FindAllStringSubmatch(s, -1)
	if len(matches) == 0 {
		return name, nil
	}

	dims := make([]string, len(matches))
	for i, m := range matches {
		dims[i] = m[1]
	}
	return name, dims
}
   231  
   232  func bindRVarsTo(t Tree, rvars []ReadVar) []ReadVar {
   233  	ors := make([]ReadVar, 0, len(rvars))
   234  	var flatten func(b Branch, rvar ReadVar) []ReadVar
   235  	flatten = func(br Branch, rvar ReadVar) []ReadVar {
   236  		nsub := len(br.Branches())
   237  		subs := make([]ReadVar, 0, nsub)
   238  		rv := reflect.ValueOf(rvar.Value).Elem()
   239  		get := func(name string) int {
   240  			rt := rv.Type()
   241  			for i := range rt.NumField() {
   242  				ft := rt.Field(i)
   243  				nn := nameOf(ft)
   244  				if nn == name {
   245  					// exact match.
   246  					return i
   247  				}
   248  				// try to remove any [xyz][range].
   249  				// do it after exact match not to shortcut arrays
   250  				if idx := strings.Index(nn, "["); idx > 0 {
   251  					nn = string(nn[:idx])
   252  				}
   253  				if nn == name {
   254  					return i
   255  				}
   256  			}
   257  			return -1
   258  		}
   259  
   260  		for _, sub := range br.Branches() {
   261  			bn := sub.Name()
   262  			if strings.Contains(bn, ".") {
   263  				toks := strings.Split(bn, ".")
   264  				bn = toks[len(toks)-1]
   265  			}
   266  			j := get(bn)
   267  			if j < 0 {
   268  				continue
   269  			}
   270  			fv := rv.Field(j)
   271  			bname := sub.Name()
   272  			lname := sub.Name()
   273  			if prefix := br.Name() + "."; strings.HasPrefix(bname, prefix) {
   274  				bname = string(bname[len(prefix):])
   275  			}
   276  			if idx := strings.Index(bname, "["); idx > 0 {
   277  				bname = string(bname[:idx])
   278  			}
   279  			if idx := strings.Index(lname, "["); idx > 0 {
   280  				lname = string(lname[:idx])
   281  			}
   282  			leaf := sub.Leaf(lname)
   283  			count := ""
   284  			if leaf != nil {
   285  				if lc := leaf.LeafCount(); lc != nil {
   286  					count = lc.Name()
   287  				}
   288  			}
   289  			subrv := ReadVar{
   290  				Name:  rvar.Name + "." + bname,
   291  				Leaf:  lname,
   292  				Value: fv.Addr().Interface(),
   293  				leaf:  leaf,
   294  				count: count,
   295  			}
   296  			switch len(sub.Branches()) {
   297  			case 0:
   298  				subs = append(subs, subrv)
   299  			default:
   300  				subs = append(subs, flatten(sub, subrv)...)
   301  			}
   302  		}
   303  		return subs
   304  	}
   305  
   306  	for i := range rvars {
   307  		var (
   308  			rvar = &rvars[i]
   309  			br   = t.Branch(rvar.Name)
   310  			leaf = br.Leaf(rvar.Leaf)
   311  			nsub = len(br.Branches())
   312  		)
   313  		switch nsub {
   314  		case 0:
   315  			rvar.leaf = leaf
   316  			ors = append(ors, *rvar)
   317  		default:
   318  			ors = append(ors, flatten(br, *rvar)...)
   319  		}
   320  	}
   321  	return ors
   322  }
   323  
   324  func newValue(leaf Leaf) any {
   325  	etype := leaf.Type()
   326  	unsigned := leaf.IsUnsigned()
   327  
   328  	switch etype.Kind() {
   329  	case reflect.Interface, reflect.Chan:
   330  		panic(fmt.Errorf("rtree: type %T not supported", reflect.New(etype).Elem().Interface()))
   331  	case reflect.Int8:
   332  		if unsigned {
   333  			etype = reflect.TypeOf(uint8(0))
   334  		}
   335  	case reflect.Int16:
   336  		if unsigned {
   337  			etype = reflect.TypeOf(uint16(0))
   338  		}
   339  	case reflect.Int32:
   340  		if unsigned {
   341  			etype = reflect.TypeOf(uint32(0))
   342  		}
   343  	case reflect.Int64:
   344  		if unsigned {
   345  			etype = reflect.TypeOf(uint64(0))
   346  		}
   347  	case reflect.Float32:
   348  		if _, ok := leaf.(*LeafF16); ok {
   349  			etype = reflect.TypeOf(root.Float16(0))
   350  		}
   351  	case reflect.Float64:
   352  		if _, ok := leaf.(*LeafD32); ok {
   353  			etype = reflect.TypeOf(root.Double32(0))
   354  		}
   355  	}
   356  
   357  	switch {
   358  	case leaf.LeafCount() != nil:
   359  		shape := leaf.Shape()
   360  		switch leaf.(type) {
   361  		case *LeafF16, *LeafD32:
   362  			// workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149
   363  			shape = nil
   364  		}
   365  		for i := range shape {
   366  			etype = reflect.ArrayOf(shape[len(shape)-1-i], etype)
   367  		}
   368  		etype = reflect.SliceOf(etype)
   369  	case leaf.Len() > 1:
   370  		shape := leaf.Shape()
   371  		switch leaf.Kind() {
   372  		case reflect.String:
   373  			switch dims := len(shape); dims {
   374  			case 0, 1:
   375  				// interpret as a single string.
   376  			default:
   377  				// FIXME(sbinet): properly handle [N]string (but ROOT doesn't support that.)
   378  				// see: https://root-forum.cern.ch/t/char-t-in-a-branch/5591/2
   379  				// etype = reflect.ArrayOf(leaf.Len(), etype)
   380  				panic(fmt.Errorf("groot/rtree: invalid number of dimensions (%d)", dims))
   381  			}
   382  		default:
   383  			switch leaf.(type) {
   384  			case *LeafF16, *LeafD32:
   385  				// workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149
   386  				shape = []int{leaf.Len()}
   387  			}
   388  			for i := range shape {
   389  				etype = reflect.ArrayOf(shape[len(shape)-1-i], etype)
   390  			}
   391  		}
   392  	}
   393  	return reflect.New(etype).Interface()
   394  }