go-hep.org/x/hep@v0.38.1/brio/cmd/brio-gen/internal/gen/gen.go (about) 1 // Copyright ©2016 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gen // import "go-hep.org/x/hep/brio/cmd/brio-gen/internal/gen" 6 7 import ( 8 "bytes" 9 "fmt" 10 "go/format" 11 "go/types" 12 "log" 13 "strings" 14 15 "golang.org/x/tools/go/packages" 16 ) 17 18 var ( 19 binMa *types.Interface // encoding.BinaryMarshaler 20 binUn *types.Interface // encoding.BinaryUnmarshaler 21 ) 22 23 // Generator holds the state of the generation. 24 type Generator struct { 25 buf *bytes.Buffer 26 pkg *types.Package 27 28 // set of imported packages. 29 // usually: "encoding/binary", "math" 30 imps map[string]int 31 32 Verbose bool // enable verbose mode 33 } 34 35 // NewGenerator returns a new code generator for package p, 36 // where p is the package's import path. 37 func NewGenerator(p string) (*Generator, error) { 38 pkg, err := importPkg(p) 39 if err != nil { 40 return nil, err 41 } 42 43 return &Generator{ 44 buf: new(bytes.Buffer), 45 pkg: pkg, 46 imps: map[string]int{"encoding/binary": 1}, 47 }, nil 48 } 49 50 func (g *Generator) printf(format string, args ...any) { 51 fmt.Fprintf(g.buf, format, args...) 52 } 53 54 func (g *Generator) Generate(typeName string) { 55 scope := g.pkg.Scope() 56 obj := scope.Lookup(typeName) 57 if obj == nil { 58 log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name()) 59 } 60 61 tn, ok := obj.(*types.TypeName) 62 if !ok { 63 log.Fatalf("%q is not a type (%v)\n", typeName, obj) 64 } 65 66 typ, ok := tn.Type().Underlying().(*types.Struct) 67 if !ok { 68 log.Fatalf("%q is not a named struct (%v)\n", typeName, tn) 69 } 70 if g.Verbose { 71 log.Printf("typ: %+v\n", typ) 72 } 73 74 g.genMarshal(typ, typeName) 75 g.genUnmarshal(typ, typeName) 76 } 77 78 func (g *Generator) genMarshal(t types.Type, typeName string) { 79 g.printf(`// MarshalBinary implements encoding.BinaryMarshaler 80 func (o *%[1]s) MarshalBinary() (data []byte, err error) { 81 var buf [8]byte 82 `, 83 typeName, 84 ) 85 86 typ := t.Underlying().(*types.Struct) 87 for i := range typ.NumFields() { 88 ft := typ.Field(i) 89 g.genMarshalType(ft.Type(), "o."+ft.Name()) 90 } 91 92 g.printf("return data, err\n}\n\n") 93 } 94 95 func (g *Generator) genMarshalType(t types.Type, n string) { 96 if types.Implements(t, binMa) || types.Implements(types.NewPointer(t), binMa) { 97 g.printf("{\nsub, err := %s.MarshalBinary()\n", n) 98 g.printf("if err != nil {\nreturn nil, err\n}\n") 99 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n") 100 g.printf("data = append(data, buf[:8]...)\n") 101 g.printf("data = append(data, sub...)\n") 102 g.printf("}\n") 103 return 104 } 105 106 ut := t.Underlying() 107 switch ut := ut.(type) { 108 case *types.Basic: 109 switch kind := ut.Kind(); kind { 110 111 case types.Bool: 112 g.printf("switch %s {\ncase false:\n data = append(data, uint8(0))\n", n) 113 g.printf("default:\ndata = append(data, uint8(1))\n}\n") 114 115 case types.Uint: 116 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", n) 117 g.printf("data = append(data, buf[:8]...)\n") 118 119 case types.Uint8: 120 g.printf("data = append(data, byte(%s))\n", n) 121 122 case types.Uint16: 123 g.printf( 124 "binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n", 125 n, 126 ) 127 g.printf("data = append(data, buf[:2]...)\n") 128 129 case types.Uint32: 130 g.printf( 131 "binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n", 132 n, 133 ) 134 g.printf("data = append(data, buf[:4]...)\n") 135 136 case types.Uint64: 137 g.printf( 138 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 139 n, 140 ) 141 g.printf("data = append(data, buf[:8]...)\n") 142 143 case types.Int: 144 g.printf( 145 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 146 n, 147 ) 148 g.printf("data = append(data, buf[:8]...)\n") 149 150 case types.Int8: 151 g.printf("data = append(data, byte(%s))\n", n) 152 153 case types.Int16: 154 g.printf( 155 "binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n", 156 n, 157 ) 158 g.printf("data = append(data, buf[:2]...)\n") 159 160 case types.Int32: 161 g.printf( 162 "binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n", 163 n, 164 ) 165 g.printf("data = append(data, buf[:4]...)\n") 166 167 case types.Int64: 168 g.printf( 169 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 170 n, 171 ) 172 g.printf("data = append(data, buf[:8]...)\n") 173 174 case types.Float32: 175 g.imps["math"] = 1 176 g.printf( 177 "binary.LittleEndian.PutUint32(buf[:4], math.Float32bits(%s))\n", 178 n, 179 ) 180 g.printf("data = append(data, buf[:4]...)\n") 181 182 case types.Float64: 183 g.imps["math"] = 1 184 g.printf( 185 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(%s))\n", 186 n, 187 ) 188 g.printf("data = append(data, buf[:8]...)\n") 189 190 case types.Complex64: 191 g.imps["math"] = 1 192 g.printf( 193 "binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(real(%s)))\n", 194 n, 195 ) 196 g.printf("data = append(data, buf[:4]...)\n") 197 g.printf( 198 "binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(imag(%s)))\n", 199 n, 200 ) 201 g.printf("data = append(data, buf[:4]...)\n") 202 203 case types.Complex128: 204 g.imps["math"] = 1 205 g.printf( 206 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(real(%s)))\n", 207 n, 208 ) 209 g.printf("data = append(data, buf[:8]...)\n") 210 g.printf( 211 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(imag(%s)))\n", 212 n, 213 ) 214 g.printf("data = append(data, buf[:8]...)\n") 215 216 case types.String: 217 g.printf( 218 "binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n", 219 n, 220 ) 221 g.printf("data = append(data, buf[:8]...)\n") 222 g.printf("data = append(data, []byte(%s)...)\n", n) 223 224 default: 225 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 226 } 227 228 case *types.Struct: 229 switch t.(type) { 230 case *types.Named: 231 g.printf("{\nsub, err := %s.MarshalBinary()\n", n) 232 g.printf("if err != nil {\nreturn nil, err\n}\n") 233 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n") 234 g.printf("data = append(data, buf[:8]...)\n") 235 g.printf("data = append(data, sub...)\n") 236 g.printf("}\n") 237 default: 238 // un-named 239 for i := range ut.NumFields() { 240 elem := ut.Field(i) 241 g.genMarshalType(elem.Type(), n+"."+elem.Name()) 242 } 243 } 244 245 case *types.Array: 246 if isByteType(ut.Elem()) { 247 g.printf("data = append(data, %s[:]...)\n", n) 248 } else { 249 g.printf("for i := range %s {\n", n) 250 if _, ok := ut.Elem().(*types.Pointer); ok { 251 g.printf("o := %s[i]\n", n) 252 } else { 253 g.printf("o := &%s[i]\n", n) 254 } 255 g.genMarshalType(ut.Elem(), "o") 256 g.printf("}\n") 257 } 258 259 case *types.Slice: 260 g.printf( 261 "binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n", 262 n, 263 ) 264 g.printf("data = append(data, buf[:8]...)\n") 265 if isByteType(ut.Elem()) { 266 g.printf("data = append(data, %s...)\n", n) 267 } else { 268 g.printf("for i := range %s {\n", n) 269 if _, ok := ut.Elem().(*types.Pointer); ok { 270 g.printf("o := %s[i]\n", n) 271 } else { 272 g.printf("o := &%s[i]\n", n) 273 } 274 g.genMarshalType(ut.Elem(), "o") 275 g.printf("}\n") 276 } 277 278 case *types.Pointer: 279 g.printf("{\n") 280 g.printf("v := *%s\n", n) 281 g.genMarshalType(ut.Elem(), "v") 282 g.printf("}\n") 283 284 case *types.Interface: 285 log.Fatalf("marshal interface not supported (type=%v)\n", t) 286 287 default: 288 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 289 } 290 } 291 292 func (g *Generator) genUnmarshal(t types.Type, typeName string) { 293 g.printf(`// UnmarshalBinary implements encoding.BinaryUnmarshaler 294 func (o *%[1]s) UnmarshalBinary(data []byte) (err error) { 295 `, 296 typeName, 297 ) 298 299 typ := t.Underlying().(*types.Struct) 300 for i := range typ.NumFields() { 301 ft := typ.Field(i) 302 g.genUnmarshalType(ft.Type(), "o."+ft.Name()) 303 } 304 305 g.printf("_ = data\n") 306 g.printf("return err\n}\n\n") 307 } 308 309 func (g *Generator) genUnmarshalType(t types.Type, n string) { 310 if types.Implements(t, binUn) || types.Implements(types.NewPointer(t), binUn) { 311 g.printf("{\n") 312 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 313 g.printf("data = data[8:]\n") 314 g.printf("err = %s.UnmarshalBinary(data[:n])\n", n) 315 g.printf("if err != nil {\nreturn err\n}\n") 316 g.printf("data = data[n:]\n") 317 g.printf("}\n") 318 return 319 } 320 321 tn := types.TypeString(t, types.RelativeTo(g.pkg)) 322 ut := t.Underlying() 323 switch ut := ut.(type) { 324 case *types.Basic: 325 switch kind := ut.Kind(); kind { 326 327 case types.Bool: 328 g.printf("switch data[i] {\ncase 0:\n%s = false\n", n) 329 g.printf("default:\n%s = true\n}\n", n) 330 g.printf("data = data[1:]\n") 331 332 case types.Uint: 333 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 334 g.printf("data = data[8:]\n") 335 336 case types.Uint8: 337 g.printf("%s = %s(data[0])\n", n, tn) 338 g.printf("data = data[1:]\n") 339 340 case types.Uint16: 341 g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn) 342 g.printf("data = data[2:]\n") 343 344 case types.Uint32: 345 g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn) 346 g.printf("data = data[4:]\n") 347 348 case types.Uint64: 349 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 350 g.printf("data = data[8:]\n") 351 352 case types.Int: 353 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 354 g.printf("data = data[8:]\n") 355 356 case types.Int8: 357 g.printf("%s = %s(data[0])\n", n, tn) 358 g.printf("data = data[1:]\n") 359 360 case types.Int16: 361 g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn) 362 g.printf("data = data[2:]\n") 363 364 case types.Int32: 365 g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn) 366 g.printf("data = data[4:]\n") 367 368 case types.Int64: 369 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 370 g.printf("data = data[8:]\n") 371 372 case types.Float32: 373 g.imps["math"] = 1 374 g.printf("%s = %s(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])))\n", n, tn) 375 g.printf("data = data[4:]\n") 376 377 case types.Float64: 378 g.imps["math"] = 1 379 g.printf("%s = %s(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])))\n", n, tn) 380 g.printf("data = data[8:]\n") 381 382 case types.Complex64: 383 g.imps["math"] = 1 384 g.printf("%s = %s(complex(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])), math.Float32frombits(binary.LittleEndian.Uint32(data[4:8]))))\n", n, tn) 385 g.printf("data = data[8:]\n") 386 387 case types.Complex128: 388 g.imps["math"] = 1 389 g.printf("%s = %s(complex(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])), math.Float64frombits(binary.LittleEndian.Uint64(data[8:16]))))\n", n, tn) 390 g.printf("data = data[16:]\n") 391 392 case types.String: 393 g.printf("{\n") 394 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 395 g.printf("data = data[8:]\n") 396 g.printf("%s = %s(data[:n])\n", n, tn) 397 g.printf("data = data[n:]\n") 398 g.printf("}\n") 399 400 default: 401 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 402 } 403 404 case *types.Struct: 405 switch t.(type) { 406 case *types.Named: 407 g.printf("{\n") 408 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 409 g.printf("data = data[8:]\n") 410 g.printf("err = %s.UnmarshalBinary(data[:n])\n", n) 411 g.printf("if err != nil {\nreturn err\n}\n") 412 g.printf("data = data[n:]\n") 413 g.printf("}\n") 414 default: 415 // un-named. 416 for i := range ut.NumFields() { 417 elem := ut.Field(i) 418 g.genUnmarshalType(elem.Type(), n+"."+elem.Name()) 419 } 420 } 421 422 case *types.Array: 423 if isByteType(ut.Elem()) { 424 g.printf("copy(%s[:], data[:n])\n", n) 425 g.printf("data = data[n:]\n") 426 } else { 427 g.printf("for i := range %s {\n", n) 428 nn := n + "[i]" 429 if pt, ok := ut.Elem().(*types.Pointer); ok { 430 g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg)) 431 nn = "oi" 432 } 433 if _, ok := ut.Elem().Underlying().(*types.Struct); ok { 434 g.printf("oi := &%s[i]\n", n) 435 nn = "oi" 436 } 437 g.genUnmarshalType(ut.Elem(), nn) 438 if _, ok := ut.Elem().(*types.Pointer); ok { 439 g.printf("%s[i] = oi\n", n) 440 } 441 g.printf("}\n") 442 } 443 444 case *types.Slice: 445 g.printf("{\n") 446 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 447 g.printf("%[1]s = make([]%[2]s, n)\n", n, qualTypeName(ut.Elem(), g.pkg)) 448 g.printf("data = data[8:]\n") 449 if isByteType(ut.Elem()) { 450 g.printf("%[1]s = append(%[1]s, data[:n]...)\n", n) 451 g.printf("data = data[n:]\n") 452 } else { 453 g.printf("for i := range %s {\n", n) 454 nn := n + "[i]" 455 if pt, ok := ut.Elem().(*types.Pointer); ok { 456 g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg)) 457 nn = "oi" 458 } 459 if _, ok := ut.Elem().Underlying().(*types.Struct); ok { 460 g.printf("oi := &%s[i]\n", n) 461 nn = "oi" 462 } 463 g.genUnmarshalType(ut.Elem(), nn) 464 if _, ok := ut.Elem().(*types.Pointer); ok { 465 g.printf("%s[i] = oi\n", n) 466 } 467 g.printf("}\n") 468 } 469 g.printf("}\n") 470 471 case *types.Pointer: 472 g.printf("{\n") 473 elt := ut.Elem() 474 g.printf("var v %s\n", qualTypeName(elt, g.pkg)) 475 g.genUnmarshalType(elt, "v") 476 g.printf("%s = &v\n\n", n) 477 g.printf("}\n") 478 479 case *types.Interface: 480 log.Fatalf("marshal interface not supported (type=%v)\n", t) 481 482 default: 483 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 484 } 485 486 } 487 488 func isByteType(t types.Type) bool { 489 b, ok := t.Underlying().(*types.Basic) 490 if !ok { 491 return false 492 } 493 return b.Kind() == types.Byte 494 } 495 496 func qualTypeName(t types.Type, pkg *types.Package) string { 497 n := types.TypeString(t, types.RelativeTo(pkg)) 498 i := strings.LastIndex(n, "/") 499 if i < 0 { 500 return n 501 } 502 return string(n[i+1:]) 503 } 504 505 func (g *Generator) Format() ([]byte, error) { 506 buf := new(bytes.Buffer) 507 508 // See standard at https://golang.org/s/generatedcode 509 buf.WriteString(fmt.Sprintf(`// Code generated by %[1]s; DO NOT EDIT. 510 511 package %[2]s 512 513 import ( 514 "encoding/binary" 515 `, 516 "brio-gen", 517 g.pkg.Name(), 518 )) 519 520 for k := range g.imps { 521 fmt.Fprintf(buf, "%q\n", k) 522 } 523 fmt.Fprintf(buf, ")\n\n") 524 525 buf.Write(g.buf.Bytes()) 526 527 src, err := format.Source(buf.Bytes()) 528 if err != nil { 529 log.Printf("=== error ===\n%s\n", buf.Bytes()) 530 } 531 return src, err 532 } 533 534 func importPkg(p string) (*types.Package, error) { 535 cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps} 536 pkgs, err := packages.Load(cfg, p) 537 if err != nil { 538 return nil, fmt.Errorf("could not load package %q: %w", p, err) 539 } 540 541 return pkgs[0].Types, nil 542 } 543 544 func init() { 545 pkg, err := importPkg("encoding") 546 if err != nil { 547 log.Fatalf("error finding package \"encoding\": %v\n", err) 548 } 549 550 o := pkg.Scope().Lookup("BinaryMarshaler") 551 if o == nil { 552 log.Fatalf("could not find interface encoding.BinaryMarshaler\n") 553 } 554 binMa = o.(*types.TypeName).Type().Underlying().(*types.Interface) 555 556 o = pkg.Scope().Lookup("BinaryUnmarshaler") 557 if o == nil { 558 log.Fatalf("could not find interface encoding.BinaryUnmarshaler\n") 559 } 560 binUn = o.(*types.TypeName).Type().Underlying().(*types.Interface) 561 }