go-hep.org/x/hep@v0.38.1/groot/internal/genroot/genrfunc.go (about)

     1  // Copyright ©2020 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package genroot // import "go-hep.org/x/hep/groot/internal/genroot"
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"go/types"
    13  	"io"
    14  	"strconv"
    15  	"strings"
    16  	"text/template"
    17  
    18  	"golang.org/x/tools/go/packages"
    19  )
    20  
    21  // RFunc describes which function should be used as a template
    22  // to implement the rtree/rfunc.Formula interface.
    23  type RFunc struct {
    24  	Pkg  string // Name of package hosting the formula to be generated.
    25  	Path string // Import path of the package holding the function.
    26  	Name string // Formula name.
    27  	Def  string // Function name or signature.
    28  }
    29  
    30  // GenRFunc generates the rtree/rfunc.Formula implementation for fct.
    31  func GenRFunc(w io.Writer, fct RFunc) error {
    32  	gen, err := NewRFuncGenerator(w, fct)
    33  	if err != nil {
    34  		return fmt.Errorf("genroot: could not create rfunc generator: %w", err)
    35  	}
    36  
    37  	err = gen.Generate()
    38  	if err != nil {
    39  		return fmt.Errorf("genroot: could not generate rfunc formula implementation: %w", err)
    40  	}
    41  	return nil
    42  }
    43  
    44  type rfuncGen struct {
    45  	w    io.Writer
    46  	f    *types.Signature
    47  	pkg  string // "rfunc." or ""
    48  	name string
    49  }
    50  
    51  func NewRFuncGenerator(w io.Writer, fct RFunc) (*rfuncGen, error) {
    52  	var (
    53  		f   *types.Signature
    54  		err error
    55  	)
    56  	switch fct.Path {
    57  	case "":
    58  		f, err = parseExpr(fct.Def)
    59  		if err != nil {
    60  			return nil, fmt.Errorf("genroot: could not parse function signature: %w", err)
    61  		}
    62  	default:
    63  		cfg := &packages.Config{
    64  			Mode: packages.NeedName |
    65  				packages.NeedFiles |
    66  				packages.NeedCompiledGoFiles |
    67  				packages.NeedSyntax |
    68  				packages.NeedTypes |
    69  				packages.NeedTypesInfo,
    70  		}
    71  		pkgs, err := packages.Load(cfg, fct.Path)
    72  		if err != nil {
    73  			return nil, fmt.Errorf("genroot: could not load package of %q %s: %w", fct.Path, fct.Name, err)
    74  		}
    75  		var pkg *packages.Package
    76  		for _, p := range pkgs {
    77  			if p.PkgPath == fct.Path {
    78  				pkg = p
    79  				break
    80  			}
    81  		}
    82  		if pkg == nil || len(pkg.Errors) > 0 {
    83  			return nil, fmt.Errorf("genroot: could not find package %q", fct.Path)
    84  		}
    85  
    86  		var (
    87  			scope = pkg.Types.Scope()
    88  		)
    89  		obj := scope.Lookup(fct.Def)
    90  		if obj == nil {
    91  			return nil, fmt.Errorf("genroot: could not find %s in package %q", fct.Def, fct.Path)
    92  		}
    93  		ft, ok := obj.(*types.Func)
    94  		if !ok {
    95  			return nil, fmt.Errorf("genroot: object %s in package %q is not a func (%T)", fct.Def, fct.Path, obj)
    96  		}
    97  		f = ft.Type().Underlying().(*types.Signature)
    98  	}
    99  
   100  	name := fct.Name
   101  	if name == "" {
   102  		switch fct.Path {
   103  		case "":
   104  			name = genRFuncName(f)
   105  		default:
   106  			name = fct.Def + "Formula"
   107  		}
   108  	}
   109  
   110  	gen := &rfuncGen{w: w, f: f, name: name}
   111  	switch fct.Pkg {
   112  	case "go-hep.org/x/hep/groot/rtree/rfunc":
   113  		// no-op.
   114  	default:
   115  		gen.pkg = "rfunc."
   116  	}
   117  
   118  	return gen, nil
   119  }
   120  
   121  func genRFuncName(sig *types.Signature) string {
   122  	o := new(strings.Builder)
   123  	o.WriteString("Func")
   124  	basic := func(k types.BasicKind) string {
   125  		switch k {
   126  		case types.Bool:
   127  			return "Bool"
   128  		case types.Uint8:
   129  			return "U8"
   130  		case types.Uint16:
   131  			return "U16"
   132  		case types.Uint32:
   133  			return "U32"
   134  		case types.Uint64:
   135  			return "U64"
   136  		case types.Int8:
   137  			return "I8"
   138  		case types.Int16:
   139  			return "I16"
   140  		case types.Int32:
   141  			return "I32"
   142  		case types.Int64:
   143  			return "I64"
   144  		case types.Float32:
   145  			return "F32"
   146  		case types.Float64:
   147  			return "F64"
   148  		case types.String:
   149  			return "Str"
   150  		}
   151  		panic(fmt.Errorf("unhandled type kind %#v", k))
   152  	}
   153  	var code func(typ types.Type) string
   154  	code = func(typ types.Type) string {
   155  		switch typ := typ.Underlying().(type) {
   156  		case *types.Basic:
   157  			return basic(typ.Kind())
   158  		case *types.Slice:
   159  			return code(typ.Elem()) + "s"
   160  		default:
   161  			panic(fmt.Errorf("unhandled type %#v", typ))
   162  		}
   163  	}
   164  
   165  	params := sig.Params()
   166  	for i := range params.Len() {
   167  		o.WriteString(code(params.At(i).Type()))
   168  	}
   169  	res := sig.Results()
   170  	if res.Len() > 0 {
   171  		o.WriteString("To")
   172  		for i := range res.Len() {
   173  			o.WriteString(code(res.At(i).Type()))
   174  		}
   175  	}
   176  	return o.String()
   177  }
   178  
   179  func (gen *rfuncGen) Generate() error {
   180  	fct := rfuncTypeFrom(gen.name, gen.f)
   181  	tmpl := template.Must(template.New("rfunc").Funcs(
   182  		template.FuncMap{
   183  			"Pkg": func() string {
   184  				return gen.pkg
   185  			},
   186  		},
   187  	).Parse(rfuncCodeTmpl))
   188  	err := tmpl.Execute(gen.w, fct)
   189  	if err != nil {
   190  		return fmt.Errorf("genroot: could not execute template for %q: %w",
   191  			fct.Name, err,
   192  		)
   193  	}
   194  	return nil
   195  }
   196  
   197  func (gen *rfuncGen) GenerateTest(w io.Writer) error {
   198  	fct := rfuncTypeFrom(gen.name, gen.f)
   199  	tmpl := template.Must(template.New("rfunc").Funcs(
   200  		template.FuncMap{
   201  			"Pkg":  func() string { return gen.pkg },
   202  			"Out0": func() string { return fct.Out[0] },
   203  		},
   204  	).Parse(rfuncTestTmpl))
   205  	err := tmpl.Execute(w, fct)
   206  	if err != nil {
   207  		return fmt.Errorf("genroot: could not execute template for %q: %w",
   208  			fct.Name, err,
   209  		)
   210  	}
   211  	return nil
   212  }
   213  
   214  func parseExpr(x string) (*types.Signature, error) {
   215  	expr, err := parser.ParseExpr(x)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("genroot: could not parse %q: %w", x, err)
   218  	}
   219  	switch expr := expr.(type) {
   220  	case *ast.FuncType:
   221  		var (
   222  			pos     token.Pos
   223  			pkg     *types.Package
   224  			par     *types.Tuple
   225  			res     *types.Tuple
   226  			sig     *types.Signature
   227  			typeFor func(typ ast.Expr) types.Type
   228  		)
   229  		typeFor = func(typ ast.Expr) types.Type {
   230  			switch typ := typ.(type) {
   231  			case *ast.Ident:
   232  				t, ok := astTypesToGoTypes[typ.Name]
   233  				if !ok {
   234  					panic(fmt.Errorf("unknown ast.Ident type name %q", typ.Name))
   235  				}
   236  				return t
   237  			case *ast.ArrayType:
   238  				elt := typeFor(typ.Elt)
   239  				switch typ.Len {
   240  				case nil:
   241  					return types.NewSlice(elt)
   242  				default:
   243  					sz, err := strconv.ParseInt(typ.Len.(*ast.Ident).String(), 10, 64)
   244  					if err != nil {
   245  						panic(fmt.Errorf("invalid array expression: %#v: %+v", typ, err))
   246  					}
   247  					return types.NewArray(elt, sz)
   248  				}
   249  			default:
   250  				panic(fmt.Errorf("unhandled ast.Expr: %#v (%T), x=%q", typ, typ, x))
   251  			}
   252  		}
   253  		mk := func(lst *ast.FieldList) *types.Tuple {
   254  			vs := make([]*types.Var, lst.NumFields())
   255  			ns := make([]string, 0, len(vs))
   256  			ts := make([]ast.Expr, 0, len(vs))
   257  			for i, vs := range lst.List {
   258  				switch len(vs.Names) {
   259  				case 0:
   260  					ns = append(ns, fmt.Sprintf("arg%02d", i))
   261  					ts = append(ts, vs.Type)
   262  				default:
   263  					for _, n := range vs.Names {
   264  						ts = append(ts, vs.Type)
   265  						ns = append(ns, n.Name)
   266  					}
   267  				}
   268  			}
   269  			for i, v := range ns {
   270  				vs[i] = types.NewVar(pos, pkg, v, typeFor(ts[i]))
   271  			}
   272  			return types.NewTuple(vs...)
   273  		}
   274  		par = mk(expr.Params)
   275  		res = mk(expr.Results)
   276  		sig = types.NewSignatureType(nil, nil, nil, par, res, false)
   277  		return sig, nil
   278  	default:
   279  		panic(fmt.Errorf("error: expr=%T", expr))
   280  	}
   281  }
   282  
   283  var (
   284  	astTypesToGoTypes = map[string]types.Type{
   285  		"bool":    types.Typ[types.Bool],
   286  		"byte":    types.Typ[types.Byte],
   287  		"uint8":   types.Typ[types.Uint8],
   288  		"uint16":  types.Typ[types.Uint16],
   289  		"uint32":  types.Typ[types.Uint32],
   290  		"uint64":  types.Typ[types.Uint64],
   291  		"int8":    types.Typ[types.Int8],
   292  		"int16":   types.Typ[types.Int16],
   293  		"int32":   types.Typ[types.Int32],
   294  		"int64":   types.Typ[types.Int64],
   295  		"uint":    types.Typ[types.Uint],
   296  		"int":     types.Typ[types.Int],
   297  		"float32": types.Typ[types.Float32],
   298  		"float64": types.Typ[types.Float64],
   299  		"string":  types.Typ[types.String],
   300  	}
   301  )
   302  
   303  type rfuncType struct {
   304  	Name string
   305  	In   []string
   306  	Out  []string
   307  }
   308  
   309  func rfuncTypeFrom(name string, sig *types.Signature) rfuncType {
   310  	var (
   311  		ps  = sig.Params()
   312  		rs  = sig.Results()
   313  		fct = rfuncType{
   314  			Name: name,
   315  			In:   make([]string, ps.Len()),
   316  			Out:  make([]string, rs.Len()),
   317  		}
   318  	)
   319  
   320  	for i := range fct.In {
   321  		fct.In[i] = ps.At(i).Type().String()
   322  	}
   323  
   324  	for i := range fct.Out {
   325  		fct.Out[i] = rs.At(i).Type().String()
   326  	}
   327  
   328  	return fct
   329  }
   330  
   331  func (f rfuncType) NumIn() int   { return len(f.In) }
   332  func (f rfuncType) NumOut() int  { return len(f.Out) }
   333  func (f rfuncType) Type() string { return f.Name }
   334  
   335  func (f rfuncType) Func() string {
   336  	sig := new(strings.Builder)
   337  	sig.WriteString("func(")
   338  	for i, typ := range f.In {
   339  		if i > 0 {
   340  			sig.WriteString(", ")
   341  		}
   342  		fmt.Fprintf(sig, "arg%02d %s", i, typ)
   343  	}
   344  	sig.WriteString(")")
   345  
   346  	sig.WriteString(f.Return())
   347  
   348  	return sig.String()
   349  }
   350  
   351  func (f rfuncType) Return() string {
   352  	sig := new(strings.Builder)
   353  	switch len(f.Out) {
   354  	case 0:
   355  		// no-op
   356  	case 1:
   357  		sig.WriteString(" ")
   358  	default:
   359  		sig.WriteString(" (")
   360  	}
   361  	for i, typ := range f.Out {
   362  		if i > 0 {
   363  			sig.WriteString(", ")
   364  		}
   365  		sig.WriteString(typ)
   366  	}
   367  	switch len(f.Out) {
   368  	case 0, 1:
   369  		// no-op
   370  	default:
   371  		sig.WriteString(")")
   372  	}
   373  
   374  	return sig.String()
   375  }
   376  
   377  func (f rfuncType) TestFunc() string {
   378  	switch f.Out[0] {
   379  	case "string":
   380  		return `"42"`
   381  	case "bool":
   382  		return "true"
   383  	case "[]float64":
   384  		return "[]float64{42}"
   385  	default:
   386  		return "42"
   387  	}
   388  }
   389  
// rfuncCodeTmpl is the text/template used by Generate to emit a formula type.
//
// It is executed with an rfuncType value (.Type, .In, .NumIn, .Func, .Return).
// It relies on a "Pkg" template function that returns the qualifier
// ("rfunc." or "") for the Formula interface in the trailing interface
// assertion.
//
// The emitted code calls fmt.Errorf, so the destination file is expected to
// import "fmt".
const rfuncCodeTmpl = `// {{.Type}} implements rfunc.Formula
type {{.Type}} struct {
{{- if gt .NumIn 0}}
	rvars []string
{{- end}}
{{- range $i, $typ := .In}}
	arg{{$i}} *{{$typ}}
{{- end}}
	fct {{.Func}}
}

// New{{.Type}} return a new formula, from the provided function.
func New{{.Type}}(rvars []string, fct {{.Func}}) *{{.Type}} {
	return &{{.Type}}{
{{- if gt .NumIn 0}}
		rvars: rvars,
{{- end}}
		fct: fct,
	}
}

{{if gt .NumIn 0}}
// RVars implements rfunc.Formula
func (f *{{.Type}}) RVars() []string { return f.rvars }
{{else}}
// RVars implements rfunc.Formula
func (f *{{.Type}}) RVars() []string { return nil }
{{end}}

// Bind implements rfunc.Formula
func (f *{{.Type}}) Bind(args []any) error {
	if got, want := len(args), {{.NumIn}}; got != want {
		return fmt.Errorf(
			"rfunc: invalid number of bind arguments (got=%d, want=%d)",
			got, want,
		)
	}
{{- range $i, $typ := .In}}
	{
		ptr, ok := args[{{$i}}].(*{{$typ}})
		if !ok {
			return fmt.Errorf(
				"rfunc: argument type {{$i}} (name=%s) mismatch: got=%T, want=*{{$typ}}",
				f.rvars[{{$i}}], args[{{$i}}],
			)
		}
		f.arg{{$i}} = ptr
	}
{{- end}}
	return nil
}

// Func implements rfunc.Formula
func (f *{{.Type}}) Func() any {
	return func() {{.Return}} {
		return f.fct(
{{- range $i, $typ := .In}}
			*f.arg{{$i}},
{{- end}}
		)
	}
}

var (
	_ {{Pkg}}Formula = (*{{.Type}})(nil)
)
`
   457  
// rfuncTestTmpl is the text/template used by GenerateTest to emit a Go test
// for a generated formula.
//
// It is executed with an rfuncType value and relies on the "Pkg" and "Out0"
// template functions. Out0 and .TestFunc both read the first result type, so
// the wrapped function must have at least one result.
//
// The generated test:
//   - checks RVars,
//   - checks that Bind rejects nil arguments and a wrong argument count,
//   - binds fresh zero values,
//   - compares the formula output with the .TestFunc value.
//
// The emitted code uses the testing and reflect packages.
const rfuncTestTmpl = `func Test{{.Type}}(t *testing.T) {
{{if gt .NumIn 0}}
	rvars := make([]string, {{.NumIn}})
{{- else}}
	var rvars []string
{{- end}}
{{- range $i, $typ := .In}}
	rvars[{{$i}}] = "name-{{$i}}"
{{- end}}

	fct := {{.Func}} {
		return {{.TestFunc}}
	}

	form := New{{.Type}}(rvars, fct)

	if got, want := form.RVars(), rvars; !reflect.DeepEqual(got, want) {
		t.Fatalf("invalid rvars: got=%#v, want=%#v", got, want)
	}

{{if gt .NumIn 0}}
	ptrs := make([]any, {{.NumIn}})
{{- range $i, $typ := .In}}
	ptrs[{{$i}}] = new({{$typ}})
{{- end}}
{{else}}
	var ptrs []any
{{- end}}

{{if gt .NumIn 0}}
	{
		bad := make([]any, len(ptrs))
		copy(bad, ptrs)
		for i := len(ptrs)-1; i >= 0; i-- {
			bad[i] = any(nil)
			err := form.Bind(bad)
			if err == nil {
				t.Fatalf("expected an error for empty iface")
			}
		}
		bad = append(bad, any(nil))
		err := form.Bind(bad)
		if err == nil {
			t.Fatalf("expected an error for invalid args length")
		}
	}
{{- else}}
	{
		bad := make([]any, 1)
		err := form.Bind(bad)
		if err == nil {
			t.Fatalf("expected an error for invalid args length")
		}
	}
{{- end}}

	err := form.Bind(ptrs)
	if err != nil {
		t.Fatalf("could not bind formula: %+v", err)
	}

	got := form.Func().(func () {{.Return}})()
	if got, want := got, {{Out0}}({{.TestFunc}}); !reflect.DeepEqual(got, want) {
		t.Fatalf("invalid output:\ngot= %v (%T)\nwant=%v (%T)", got, got, want, want)
	}
}
`