go-hep.org/x/hep@v0.38.1/groot/cmd/root-gen-datareader/main.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Command root-gen-datareader generates a Go struct to easily read the
     6  // event data type stored inside a Tree.
     7  //
     8  // Example:
     9  //
    10  //	$> root-gen-datareader -t tree testdata/small-flat-tree.root
    11  //	// automatically generated by root-gen-datareader.
    12  //	// DO NOT EDIT.
    13  //
    14  //	package event
    15  //
    16  //	type Data struct {
    17  //		ROOT_Int32        int32       `groot:"Int32"`
    18  //		ROOT_Int64        int64       `groot:"Int64"`
//		ROOT_UInt32       uint32      `groot:"UInt32"`
//		ROOT_UInt64       uint64      `groot:"UInt64"`
    21  //		ROOT_Float32      float32     `groot:"Float32"`
    22  //		ROOT_Float64      float64     `groot:"Float64"`
    23  //		ROOT_ArrayInt32   [10]int32   `groot:"ArrayInt32[10]"`
    24  //		ROOT_ArrayInt64   [10]int64   `groot:"ArrayInt64[10]"`
//		ROOT_ArrayUInt32  [10]uint32  `groot:"ArrayUInt32[10]"`
//		ROOT_ArrayUInt64  [10]uint64  `groot:"ArrayUInt64[10]"`
    27  //		ROOT_ArrayFloat32 [10]float32 `groot:"ArrayFloat32[10]"`
    28  //		ROOT_ArrayFloat64 [10]float64 `groot:"ArrayFloat64[10]"`
    29  //		ROOT_N            int32       `groot:"N"`
    30  //		ROOT_SliceInt32   []int32     `groot:"SliceInt32[N]"`
    31  //		ROOT_SliceInt64   []int64     `groot:"SliceInt64[N]"`
//		ROOT_SliceUInt32  []uint32    `groot:"SliceUInt32[N]"`
//		ROOT_SliceUInt64  []uint64    `groot:"SliceUInt64[N]"`
    34  //		ROOT_SliceFloat32 []float32   `groot:"SliceFloat32[N]"`
    35  //		ROOT_SliceFloat64 []float64   `groot:"SliceFloat64[N]"`
    36  //	}
    37  package main // import "go-hep.org/x/hep/groot/cmd/root-gen-datareader"
    38  
    39  import (
    40  	"bytes"
    41  	"flag"
    42  	"fmt"
    43  	"go/format"
    44  	"io"
    45  	"log"
    46  	"os"
    47  	"reflect"
    48  	"strings"
    49  	"text/template"
    50  
    51  	"go-hep.org/x/hep/groot"
    52  	"go-hep.org/x/hep/groot/riofs"
    53  	"go-hep.org/x/hep/groot/rtree"
    54  )
    55  
    56  func main() {
    57  
    58  	log.SetPrefix("root-gen-datareader: ")
    59  	log.SetFlags(0)
    60  
    61  	var (
    62  		treeName   = flag.String("t", "tree", "name of the tree to inspect")
    63  		pkgName    = flag.String("p", "event", "name of the package where to generate the data model")
    64  		outName    = flag.String("o", "", "name of the file where to store the generated data model (STDOUT)")
    65  		dataReader = flag.Bool("reader", false, "generate data reader code")
    66  		verbose    = flag.Bool("v", false, "enable verbose mode")
    67  	)
    68  
    69  	flag.Parse()
    70  
    71  	if flag.NArg() <= 0 {
    72  		log.Printf("missing input file name")
    73  		flag.Usage()
    74  		flag.PrintDefaults()
    75  		os.Exit(1)
    76  	}
    77  
    78  	ctx := newContext(*pkgName, flag.Arg(0), *treeName, *dataReader, *verbose)
    79  
    80  	var o io.Writer = os.Stdout
    81  	if *outName != "" {
    82  		f, err := os.Create(*outName)
    83  		if err != nil {
    84  			log.Fatalf("could not create output file: %+v", err)
    85  		}
    86  		defer f.Close()
    87  		o = f
    88  	}
    89  
    90  	err := process(o, ctx)
    91  	if err != nil {
    92  		log.Fatalf("could not generate data-reader: %+v", err)
    93  	}
    94  }
    95  
    96  func process(w io.Writer, ctx *Context) error {
    97  	f, err := groot.Open(ctx.File)
    98  	if err != nil {
    99  		log.Fatal(err)
   100  	}
   101  	defer f.Close()
   102  
   103  	obj, err := riofs.Dir(f).Get(ctx.Tree)
   104  	if err != nil {
   105  		return fmt.Errorf("could not retrieve tree %q: %w", ctx.Tree, err)
   106  	}
   107  	var (
   108  		tree  = obj.(rtree.Tree)
   109  		rvars = rtree.NewReadVars(tree)
   110  	)
   111  	ctx.printf("entries: %v\n", tree.Entries())
   112  
   113  	for i, rvar := range rvars {
   114  		ctx.printf("rvar[%03d]: %q (%v)", i, rvar.Name, reflect.Indirect(reflect.ValueOf(rvar.Value)).Kind())
   115  		rv := reflect.Indirect(reflect.ValueOf(rvar.Value))
   116  		switch rv.Kind() {
   117  		case reflect.Struct:
   118  			def := structDefFrom(ctx, rvar.Name, reflect.Indirect(reflect.ValueOf(rvar.Value)).Type())
   119  			ctx.Defs[rvar.Name] = def
   120  			ctx.DataReader.Fields = append(
   121  				ctx.DataReader.Fields,
   122  				FieldDef{
   123  					Name:    goName(rvar.Name),
   124  					Tag:     rvar.Name,
   125  					VarName: rvar.Name,
   126  					Type:    def.Name,
   127  				},
   128  			)
   129  		case reflect.Array:
   130  			ctx.DataReader.Fields = append(
   131  				ctx.DataReader.Fields,
   132  				FieldDef{
   133  					Name:    goName(rvar.Name),
   134  					Tag:     fmt.Sprintf("%s[%d]", rvar.Name, rv.Type().Len()),
   135  					VarName: rvar.Name,
   136  					Type:    fmt.Sprintf("%T", rv.Interface()),
   137  				},
   138  			)
   139  			ctx.checkType(rv.Type().Elem())
   140  
   141  		case reflect.Slice:
   142  			ctx.DataReader.Fields = append(
   143  				ctx.DataReader.Fields,
   144  				FieldDef{
   145  					Name:    goName(rvar.Name),
   146  					Tag:     rvar.Name,
   147  					VarName: rvar.Name,
   148  					Type:    fmt.Sprintf("%T", rv.Interface()),
   149  				},
   150  			)
   151  			ctx.checkType(rv.Type().Elem())
   152  
   153  		default:
   154  			ctx.DataReader.Fields = append(
   155  				ctx.DataReader.Fields,
   156  				FieldDef{
   157  					Name:    goName(rvar.Name),
   158  					Tag:     rvar.Name,
   159  					VarName: rvar.Name,
   160  					Type:    fmt.Sprintf("%T", rv.Interface()),
   161  				},
   162  			)
   163  			ctx.checkType(rv.Type())
   164  		}
   165  	}
   166  	delete(ctx.Defs, "DataReader")
   167  
   168  	err = ctx.genCode(w)
   169  	if err != nil {
   170  		return fmt.Errorf("could not generate reader code: %w", err)
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  func goName(s string) string {
   177  	return "ROOT_" + s
   178  }
   179  
// StructDef models a TTree content as a Go struct, where each TBranch
// of the tree is translated as a field of the struct.
// The same type also describes struct-valued branches and nested struct
// fields, which are emitted as their own named Go types.
type StructDef struct {
	// Name is the Go type name emitted by the template ("type {{.Name}} struct").
	Name   string
	// Fields are the struct fields, emitted in slice order.
	Fields []FieldDef
}
   186  
// FieldDef describes a Go struct field, corresponding to a TTree's branch.
type FieldDef struct {
	// Name is the Go field name in the generated struct.
	Name    string
	// Type is the Go type of the field, spelled as it appears in the
	// generated source (e.g. "int32", "[10]float64", or a local struct name).
	Type    string
	// VarName is the name of the branch/read-var the field was built from.
	// NOTE(review): not referenced by codeTmpl in this file; confirm whether
	// other code relies on it.
	VarName string
	// Tag is the value written in the field's `groot:"..."` struct tag.
	Tag     string
}
   194  
// Context holds together the information about the TTree being processed
// and the state accumulated while generating code for it. It is also the
// data value handed to codeTmpl by genCode.
type Context struct {
	// Package is the name of the generated Go package.
	Package       string
	// Imports maps each import path needed by the generated code to the
	// number of times it was requested; only the keys are emitted.
	Imports       map[string]int
	// DataReader collects one field per top-level read-var; its fields are
	// rendered as the generated Data struct.
	DataReader    *StructDef
	// Defs holds the type definitions for struct-valued branches, keyed by
	// name. newContext also stores DataReader here under "DataReader";
	// process removes that entry before generating code.
	Defs          map[string]*StructDef
	// GenDataReader requests the additional DataReader type (-reader flag).
	GenDataReader bool
	// File is the path of the input ROOT file.
	File          string
	// Tree is the name of the tree to inspect inside File.
	Tree          string
	// Verbose enables the diagnostics printed by printf.
	Verbose       bool
}
   206  
   207  func newContext(pkg, file, tree string, dataReader, verbose bool) *Context {
   208  	ctx := &Context{
   209  		Package: pkg,
   210  		Imports: make(map[string]int),
   211  		Defs: map[string]*StructDef{
   212  			"DataReader": {
   213  				Name:   "DataReader",
   214  				Fields: nil,
   215  			},
   216  		},
   217  		GenDataReader: dataReader,
   218  		File:          file,
   219  		Tree:          tree,
   220  		Verbose:       verbose,
   221  	}
   222  	ctx.DataReader = ctx.Defs["DataReader"]
   223  	if dataReader {
   224  		ctx.Imports["go-hep.org/x/hep/groot/rtree"]++
   225  	}
   226  	return ctx
   227  }
   228  
   229  func (ctx *Context) printf(format string, args ...any) {
   230  	if ctx.Verbose {
   231  		log.Printf(format, args...)
   232  	}
   233  }
   234  
   235  func (ctx *Context) genCode(w io.Writer) error {
   236  	t := template.New("top")
   237  	template.Must(t.Parse(codeTmpl))
   238  	buf := new(bytes.Buffer)
   239  	err := t.Execute(buf, ctx)
   240  	if err != nil {
   241  		return err
   242  	}
   243  	src, err := format.Source(buf.Bytes())
   244  	if err != nil {
   245  		log.Printf("source:\n%s\n", buf.Bytes())
   246  		log.Printf("error: %+v", err)
   247  		return err
   248  	}
   249  	_, err = w.Write(src)
   250  	return err
   251  }
   252  
// codeTmpl is the text/template that genCode executes with the *Context as
// data. Its output is passed through go/format, so most of its layout is
// normalized. In order, it emits:
//   - the "DO NOT EDIT" header and the package clause;
//   - an import block listing the keys of Imports, only if Imports is non-empty;
//   - one struct type per entry of Defs (text/template visits map entries in
//     sorted key order, so the output is deterministic);
//   - the Data struct, built from the fields of DataReader;
//   - when GenDataReader is set, a DataReader type bundling Data with its
//     rtree.Tree and *rtree.Reader.
//
// Struct tags are spliced in with "+" because a raw string literal cannot
// contain a backquote.
const codeTmpl = `// automatically generated by root-gen-datareader.
// DO NOT EDIT.

package {{.Package}}

{{$length := len .Imports}}{{if gt $length 0}}
import (
{{range $key, $value := .Imports}}
"{{$key}}"
{{- end}}
)
{{- end}}

{{range .Defs}}
type {{.Name}} struct {
{{range .Fields}}	{{.Name}} {{.Type}} ` + "`groot:\"{{.Tag}}\"`" + `
{{end}}}
{{end}}


{{with .DataReader}}
// Data is the data contained in a rtree.Tree.
type Data struct {
{{ range .Fields}}	{{.Name}} {{.Type}} ` + "`groot:\"{{.Tag}}\"`" + `
{{ end}}}
{{end}}

{{if .GenDataReader}}
{{with .DataReader}}
type DataReader struct {
	Data   Data
	Tree   rtree.Tree
	Reader *rtree.Reader
}
{{end}}
{{end}}
`
   290  
   291  func structDefFrom(ctx *Context, name string, rt reflect.Type) *StructDef {
   292  	def := StructDef{
   293  		Name:   name,
   294  		Fields: make([]FieldDef, rt.NumField()),
   295  	}
   296  	for i := range def.Fields {
   297  		ft := rt.Field(i)
   298  		def.Fields[i] = fieldDefFrom(ctx, ft)
   299  	}
   300  
   301  	return &def
   302  }
   303  
   304  func fieldDefFrom(ctx *Context, typ reflect.StructField) FieldDef {
   305  	tag := typ.Tag.Get("groot")
   306  	ctx.checkType(typ.Type)
   307  
   308  	switch typ.Type.Kind() {
   309  	case reflect.Struct:
   310  		branch := tag
   311  		if i := strings.Index(branch, "["); i > 0 {
   312  			branch = branch[:i]
   313  		}
   314  		ctx.Defs[branch] = structDefFrom(ctx, branch, typ.Type)
   315  		return FieldDef{
   316  			Name:    typ.Name,
   317  			Type:    branch,
   318  			VarName: tag,
   319  			Tag:     tag,
   320  		}
   321  	default:
   322  		return FieldDef{
   323  			Name:    typ.Name,
   324  			Type:    typ.Type.String(),
   325  			VarName: tag,
   326  			Tag:     tag,
   327  		}
   328  	}
   329  }
   330  
   331  func (ctx *Context) checkType(typ reflect.Type) {
   332  	pkg := typ.PkgPath()
   333  	if pkg != "" {
   334  		ctx.Imports[pkg]++
   335  	}
   336  }