go-hep.org/x/hep@v0.38.1/groot/internal/genroot/genrfunc.go (about) 1 // Copyright ©2020 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package genroot // import "go-hep.org/x/hep/groot/internal/genroot" 6 7 import ( 8 "fmt" 9 "go/ast" 10 "go/parser" 11 "go/token" 12 "go/types" 13 "io" 14 "strconv" 15 "strings" 16 "text/template" 17 18 "golang.org/x/tools/go/packages" 19 ) 20 21 // RFunc describes which function should be used as a template 22 // to implement the rtree/rfunc.Formula interface. 23 type RFunc struct { 24 Pkg string // Name of package hosting the formula to be generated. 25 Path string // Import path of the package holding the function. 26 Name string // Formula name. 27 Def string // Function name or signature. 28 } 29 30 // GenRFunc generates the rtree/rfunc.Formula implementation for fct. 31 func GenRFunc(w io.Writer, fct RFunc) error { 32 gen, err := NewRFuncGenerator(w, fct) 33 if err != nil { 34 return fmt.Errorf("genroot: could not create rfunc generator: %w", err) 35 } 36 37 err = gen.Generate() 38 if err != nil { 39 return fmt.Errorf("genroot: could not generate rfunc formula implementation: %w", err) 40 } 41 return nil 42 } 43 44 type rfuncGen struct { 45 w io.Writer 46 f *types.Signature 47 pkg string // "rfunc." or "" 48 name string 49 } 50 51 func NewRFuncGenerator(w io.Writer, fct RFunc) (*rfuncGen, error) { 52 var ( 53 f *types.Signature 54 err error 55 ) 56 switch fct.Path { 57 case "": 58 f, err = parseExpr(fct.Def) 59 if err != nil { 60 return nil, fmt.Errorf("genroot: could not parse function signature: %w", err) 61 } 62 default: 63 cfg := &packages.Config{ 64 Mode: packages.NeedName | 65 packages.NeedFiles | 66 packages.NeedCompiledGoFiles | 67 packages.NeedSyntax | 68 packages.NeedTypes | 69 packages.NeedTypesInfo, 70 } 71 pkgs, err := packages.Load(cfg, fct.Path) 72 if err != nil { 73 return nil, fmt.Errorf("genroot: could not load package of %q %s: %w", fct.Path, fct.Name, err) 74 } 75 var pkg *packages.Package 76 for _, p := range pkgs { 77 if p.PkgPath == fct.Path { 78 pkg = p 79 break 80 } 81 } 82 if pkg == nil || len(pkg.Errors) > 0 { 83 return nil, fmt.Errorf("genroot: could not find package %q", fct.Path) 84 } 85 86 var ( 87 scope = pkg.Types.Scope() 88 ) 89 obj := scope.Lookup(fct.Def) 90 if obj == nil { 91 return nil, fmt.Errorf("genroot: could not find %s in package %q", fct.Def, fct.Path) 92 } 93 ft, ok := obj.(*types.Func) 94 if !ok { 95 return nil, fmt.Errorf("genroot: object %s in package %q is not a func (%T)", fct.Def, fct.Path, obj) 96 } 97 f = ft.Type().Underlying().(*types.Signature) 98 } 99 100 name := fct.Name 101 if name == "" { 102 switch fct.Path { 103 case "": 104 name = genRFuncName(f) 105 default: 106 name = fct.Def + "Formula" 107 } 108 } 109 110 gen := &rfuncGen{w: w, f: f, name: name} 111 switch fct.Pkg { 112 case "go-hep.org/x/hep/groot/rtree/rfunc": 113 // no-op. 114 default: 115 gen.pkg = "rfunc." 116 } 117 118 return gen, nil 119 } 120 121 func genRFuncName(sig *types.Signature) string { 122 o := new(strings.Builder) 123 o.WriteString("Func") 124 basic := func(k types.BasicKind) string { 125 switch k { 126 case types.Bool: 127 return "Bool" 128 case types.Uint8: 129 return "U8" 130 case types.Uint16: 131 return "U16" 132 case types.Uint32: 133 return "U32" 134 case types.Uint64: 135 return "U64" 136 case types.Int8: 137 return "I8" 138 case types.Int16: 139 return "I16" 140 case types.Int32: 141 return "I32" 142 case types.Int64: 143 return "I64" 144 case types.Float32: 145 return "F32" 146 case types.Float64: 147 return "F64" 148 case types.String: 149 return "Str" 150 } 151 panic(fmt.Errorf("unhandled type kind %#v", k)) 152 } 153 var code func(typ types.Type) string 154 code = func(typ types.Type) string { 155 switch typ := typ.Underlying().(type) { 156 case *types.Basic: 157 return basic(typ.Kind()) 158 case *types.Slice: 159 return code(typ.Elem()) + "s" 160 default: 161 panic(fmt.Errorf("unhandled type %#v", typ)) 162 } 163 } 164 165 params := sig.Params() 166 for i := range params.Len() { 167 o.WriteString(code(params.At(i).Type())) 168 } 169 res := sig.Results() 170 if res.Len() > 0 { 171 o.WriteString("To") 172 for i := range res.Len() { 173 o.WriteString(code(res.At(i).Type())) 174 } 175 } 176 return o.String() 177 } 178 179 func (gen *rfuncGen) Generate() error { 180 fct := rfuncTypeFrom(gen.name, gen.f) 181 tmpl := template.Must(template.New("rfunc").Funcs( 182 template.FuncMap{ 183 "Pkg": func() string { 184 return gen.pkg 185 }, 186 }, 187 ).Parse(rfuncCodeTmpl)) 188 err := tmpl.Execute(gen.w, fct) 189 if err != nil { 190 return fmt.Errorf("genroot: could not execute template for %q: %w", 191 fct.Name, err, 192 ) 193 } 194 return nil 195 } 196 197 func (gen *rfuncGen) GenerateTest(w io.Writer) error { 198 fct := rfuncTypeFrom(gen.name, gen.f) 199 tmpl := template.Must(template.New("rfunc").Funcs( 200 template.FuncMap{ 201 "Pkg": func() string { return gen.pkg }, 202 "Out0": func() string { return fct.Out[0] }, 203 }, 204 ).Parse(rfuncTestTmpl)) 205 err := tmpl.Execute(w, fct) 206 if err != nil { 207 return fmt.Errorf("genroot: could not execute template for %q: %w", 208 fct.Name, err, 209 ) 210 } 211 return nil 212 } 213 214 func parseExpr(x string) (*types.Signature, error) { 215 expr, err := parser.ParseExpr(x) 216 if err != nil { 217 return nil, fmt.Errorf("genroot: could not parse %q: %w", x, err) 218 } 219 switch expr := expr.(type) { 220 case *ast.FuncType: 221 var ( 222 pos token.Pos 223 pkg *types.Package 224 par *types.Tuple 225 res *types.Tuple 226 sig *types.Signature 227 typeFor func(typ ast.Expr) types.Type 228 ) 229 typeFor = func(typ ast.Expr) types.Type { 230 switch typ := typ.(type) { 231 case *ast.Ident: 232 t, ok := astTypesToGoTypes[typ.Name] 233 if !ok { 234 panic(fmt.Errorf("unknown ast.Ident type name %q", typ.Name)) 235 } 236 return t 237 case *ast.ArrayType: 238 elt := typeFor(typ.Elt) 239 switch typ.Len { 240 case nil: 241 return types.NewSlice(elt) 242 default: 243 sz, err := strconv.ParseInt(typ.Len.(*ast.Ident).String(), 10, 64) 244 if err != nil { 245 panic(fmt.Errorf("invalid array expression: %#v: %+v", typ, err)) 246 } 247 return types.NewArray(elt, sz) 248 } 249 default: 250 panic(fmt.Errorf("unhandled ast.Expr: %#v (%T), x=%q", typ, typ, x)) 251 } 252 } 253 mk := func(lst *ast.FieldList) *types.Tuple { 254 vs := make([]*types.Var, lst.NumFields()) 255 ns := make([]string, 0, len(vs)) 256 ts := make([]ast.Expr, 0, len(vs)) 257 for i, vs := range lst.List { 258 switch len(vs.Names) { 259 case 0: 260 ns = append(ns, fmt.Sprintf("arg%02d", i)) 261 ts = append(ts, vs.Type) 262 default: 263 for _, n := range vs.Names { 264 ts = append(ts, vs.Type) 265 ns = append(ns, n.Name) 266 } 267 } 268 } 269 for i, v := range ns { 270 vs[i] = types.NewVar(pos, pkg, v, typeFor(ts[i])) 271 } 272 return types.NewTuple(vs...) 273 } 274 par = mk(expr.Params) 275 res = mk(expr.Results) 276 sig = types.NewSignatureType(nil, nil, nil, par, res, false) 277 return sig, nil 278 default: 279 panic(fmt.Errorf("error: expr=%T", expr)) 280 } 281 } 282 283 var ( 284 astTypesToGoTypes = map[string]types.Type{ 285 "bool": types.Typ[types.Bool], 286 "byte": types.Typ[types.Byte], 287 "uint8": types.Typ[types.Uint8], 288 "uint16": types.Typ[types.Uint16], 289 "uint32": types.Typ[types.Uint32], 290 "uint64": types.Typ[types.Uint64], 291 "int8": types.Typ[types.Int8], 292 "int16": types.Typ[types.Int16], 293 "int32": types.Typ[types.Int32], 294 "int64": types.Typ[types.Int64], 295 "uint": types.Typ[types.Uint], 296 "int": types.Typ[types.Int], 297 "float32": types.Typ[types.Float32], 298 "float64": types.Typ[types.Float64], 299 "string": types.Typ[types.String], 300 } 301 ) 302 303 type rfuncType struct { 304 Name string 305 In []string 306 Out []string 307 } 308 309 func rfuncTypeFrom(name string, sig *types.Signature) rfuncType { 310 var ( 311 ps = sig.Params() 312 rs = sig.Results() 313 fct = rfuncType{ 314 Name: name, 315 In: make([]string, ps.Len()), 316 Out: make([]string, rs.Len()), 317 } 318 ) 319 320 for i := range fct.In { 321 fct.In[i] = ps.At(i).Type().String() 322 } 323 324 for i := range fct.Out { 325 fct.Out[i] = rs.At(i).Type().String() 326 } 327 328 return fct 329 } 330 331 func (f rfuncType) NumIn() int { return len(f.In) } 332 func (f rfuncType) NumOut() int { return len(f.Out) } 333 func (f rfuncType) Type() string { return f.Name } 334 335 func (f rfuncType) Func() string { 336 sig := new(strings.Builder) 337 sig.WriteString("func(") 338 for i, typ := range f.In { 339 if i > 0 { 340 sig.WriteString(", ") 341 } 342 fmt.Fprintf(sig, "arg%02d %s", i, typ) 343 } 344 sig.WriteString(")") 345 346 sig.WriteString(f.Return()) 347 348 return sig.String() 349 } 350 351 func (f rfuncType) Return() string { 352 sig := new(strings.Builder) 353 switch len(f.Out) { 354 case 0: 355 // no-op 356 case 1: 357 sig.WriteString(" ") 358 default: 359 sig.WriteString(" (") 360 } 361 for i, typ := range f.Out { 362 if i > 0 { 363 sig.WriteString(", ") 364 } 365 sig.WriteString(typ) 366 } 367 switch len(f.Out) { 368 case 0, 1: 369 // no-op 370 default: 371 sig.WriteString(")") 372 } 373 374 return sig.String() 375 } 376 377 func (f rfuncType) TestFunc() string { 378 switch f.Out[0] { 379 case "string": 380 return `"42"` 381 case "bool": 382 return "true" 383 case "[]float64": 384 return "[]float64{42}" 385 default: 386 return "42" 387 } 388 } 389 390 const rfuncCodeTmpl = `// {{.Type}} implements rfunc.Formula 391 type {{.Type}} struct { 392 {{- if gt .NumIn 0}} 393 rvars []string 394 {{- end}} 395 {{- range $i, $typ := .In}} 396 arg{{$i}} *{{$typ}} 397 {{- end}} 398 fct {{.Func}} 399 } 400 401 // New{{.Type}} return a new formula, from the provided function. 402 func New{{.Type}}(rvars []string, fct {{.Func}}) *{{.Type}} { 403 return &{{.Type}}{ 404 {{- if gt .NumIn 0}} 405 rvars: rvars, 406 {{- end}} 407 fct: fct, 408 } 409 } 410 411 {{if gt .NumIn 0}} 412 // RVars implements rfunc.Formula 413 func (f *{{.Type}}) RVars() []string { return f.rvars } 414 {{else}} 415 // RVars implements rfunc.Formula 416 func (f *{{.Type}}) RVars() []string { return nil } 417 {{end}} 418 419 // Bind implements rfunc.Formula 420 func (f *{{.Type}}) Bind(args []any) error { 421 if got, want := len(args), {{.NumIn}}; got != want { 422 return fmt.Errorf( 423 "rfunc: invalid number of bind arguments (got=%d, want=%d)", 424 got, want, 425 ) 426 } 427 {{- range $i, $typ := .In}} 428 { 429 ptr, ok := args[{{$i}}].(*{{$typ}}) 430 if !ok { 431 return fmt.Errorf( 432 "rfunc: argument type {{$i}} (name=%s) mismatch: got=%T, want=*{{$typ}}", 433 f.rvars[{{$i}}], args[{{$i}}], 434 ) 435 } 436 f.arg{{$i}} = ptr 437 } 438 {{- end}} 439 return nil 440 } 441 442 // Func implements rfunc.Formula 443 func (f *{{.Type}}) Func() any { 444 return func() {{.Return}} { 445 return f.fct( 446 {{- range $i, $typ := .In}} 447 *f.arg{{$i}}, 448 {{- end}} 449 ) 450 } 451 } 452 453 var ( 454 _ {{Pkg}}Formula = (*{{.Type}})(nil) 455 ) 456 ` 457 458 const rfuncTestTmpl = `func Test{{.Type}}(t *testing.T) { 459 {{if gt .NumIn 0}} 460 rvars := make([]string, {{.NumIn}}) 461 {{- else}} 462 var rvars []string 463 {{- end}} 464 {{- range $i, $typ := .In}} 465 rvars[{{$i}}] = "name-{{$i}}" 466 {{- end}} 467 468 fct := {{.Func}} { 469 return {{.TestFunc}} 470 } 471 472 form := New{{.Type}}(rvars, fct) 473 474 if got, want := form.RVars(), rvars; !reflect.DeepEqual(got, want) { 475 t.Fatalf("invalid rvars: got=%#v, want=%#v", got, want) 476 } 477 478 {{if gt .NumIn 0}} 479 ptrs := make([]any, {{.NumIn}}) 480 {{- range $i, $typ := .In}} 481 ptrs[{{$i}}] = new({{$typ}}) 482 {{- end}} 483 {{else}} 484 var ptrs []any 485 {{- end}} 486 487 {{if gt .NumIn 0}} 488 { 489 bad := make([]any, len(ptrs)) 490 copy(bad, ptrs) 491 for i := len(ptrs)-1; i >= 0; i-- { 492 bad[i] = any(nil) 493 err := form.Bind(bad) 494 if err == nil { 495 t.Fatalf("expected an error for empty iface") 496 } 497 } 498 bad = append(bad, any(nil)) 499 err := form.Bind(bad) 500 if err == nil { 501 t.Fatalf("expected an error for invalid args length") 502 } 503 } 504 {{- else}} 505 { 506 bad := make([]any, 1) 507 err := form.Bind(bad) 508 if err == nil { 509 t.Fatalf("expected an error for invalid args length") 510 } 511 } 512 {{- end}} 513 514 err := form.Bind(ptrs) 515 if err != nil { 516 t.Fatalf("could not bind formula: %+v", err) 517 } 518 519 got := form.Func().(func () {{.Return}})() 520 if got, want := got, {{Out0}}({{.TestFunc}}); !reflect.DeepEqual(got, want) { 521 t.Fatalf("invalid output:\ngot= %v (%T)\nwant=%v (%T)", got, got, want, want) 522 } 523 } 524 `