go-hep.org/x/hep@v0.40.0/brio/cmd/brio-gen/internal/gen/gen.go (about) 1 // Copyright ©2016 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gen // import "go-hep.org/x/hep/brio/cmd/brio-gen/internal/gen" 6 7 import ( 8 "bytes" 9 "fmt" 10 "go/format" 11 "go/types" 12 "log" 13 "strings" 14 15 "golang.org/x/tools/go/packages" 16 ) 17 18 var ( 19 binMa *types.Interface // encoding.BinaryMarshaler 20 binUn *types.Interface // encoding.BinaryUnmarshaler 21 ) 22 23 // Generator holds the state of the generation. 24 type Generator struct { 25 buf *bytes.Buffer 26 pkg *types.Package 27 28 // set of imported packages. 29 // usually: "encoding/binary", "math" 30 imps map[string]int 31 32 Verbose bool // enable verbose mode 33 } 34 35 // NewGenerator returns a new code generator for package p, 36 // where p is the package's import path. 37 func NewGenerator(p string) (*Generator, error) { 38 pkg, err := importPkg(p) 39 if err != nil { 40 return nil, err 41 } 42 43 return &Generator{ 44 buf: new(bytes.Buffer), 45 pkg: pkg, 46 imps: map[string]int{"encoding/binary": 1}, 47 }, nil 48 } 49 50 func (g *Generator) printf(format string, args ...any) { 51 fmt.Fprintf(g.buf, format, args...) 52 } 53 54 func (g *Generator) Generate(typeName string) { 55 scope := g.pkg.Scope() 56 obj := scope.Lookup(typeName) 57 if obj == nil { 58 log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name()) 59 } 60 61 tn, ok := obj.(*types.TypeName) 62 if !ok { 63 log.Fatalf("%q is not a type (%v)\n", typeName, obj) 64 } 65 66 typ, ok := tn.Type().Underlying().(*types.Struct) 67 if !ok { 68 log.Fatalf("%q is not a named struct (%v)\n", typeName, tn) 69 } 70 if g.Verbose { 71 log.Printf("typ: %+v\n", typ) 72 } 73 74 g.genMarshal(typ, typeName) 75 g.genUnmarshal(typ, typeName) 76 } 77 78 func (g *Generator) genMarshal(t types.Type, typeName string) { 79 g.printf(`// MarshalBinary implements encoding.BinaryMarshaler 80 func (o *%[1]s) MarshalBinary() (data []byte, err error) { 81 var buf [8]byte 82 `, 83 typeName, 84 ) 85 86 typ := t.Underlying().(*types.Struct) 87 for ft := range typ.Fields() { 88 g.genMarshalType(ft.Type(), "o."+ft.Name()) 89 } 90 91 g.printf("return data, err\n}\n\n") 92 } 93 94 func (g *Generator) genMarshalType(t types.Type, n string) { 95 if types.Implements(t, binMa) || types.Implements(types.NewPointer(t), binMa) { 96 g.printf("{\nsub, err := %s.MarshalBinary()\n", n) 97 g.printf("if err != nil {\nreturn nil, err\n}\n") 98 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n") 99 g.printf("data = append(data, buf[:8]...)\n") 100 g.printf("data = append(data, sub...)\n") 101 g.printf("}\n") 102 return 103 } 104 105 ut := t.Underlying() 106 switch ut := ut.(type) { 107 case *types.Basic: 108 switch kind := ut.Kind(); kind { 109 110 case types.Bool: 111 g.printf("switch %s {\ncase false:\n data = append(data, uint8(0))\n", n) 112 g.printf("default:\ndata = append(data, uint8(1))\n}\n") 113 114 case types.Uint: 115 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", n) 116 g.printf("data = append(data, buf[:8]...)\n") 117 118 case types.Uint8: 119 g.printf("data = append(data, byte(%s))\n", n) 120 121 case types.Uint16: 122 g.printf( 123 "binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n", 124 n, 125 ) 126 g.printf("data = append(data, buf[:2]...)\n") 127 128 case types.Uint32: 129 g.printf( 130 "binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n", 131 n, 132 ) 133 g.printf("data = append(data, buf[:4]...)\n") 134 135 case types.Uint64: 136 g.printf( 137 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 138 n, 139 ) 140 g.printf("data = append(data, buf[:8]...)\n") 141 142 case types.Int: 143 g.printf( 144 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 145 n, 146 ) 147 g.printf("data = append(data, buf[:8]...)\n") 148 149 case types.Int8: 150 g.printf("data = append(data, byte(%s))\n", n) 151 152 case types.Int16: 153 g.printf( 154 "binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n", 155 n, 156 ) 157 g.printf("data = append(data, buf[:2]...)\n") 158 159 case types.Int32: 160 g.printf( 161 "binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n", 162 n, 163 ) 164 g.printf("data = append(data, buf[:4]...)\n") 165 166 case types.Int64: 167 g.printf( 168 "binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", 169 n, 170 ) 171 g.printf("data = append(data, buf[:8]...)\n") 172 173 case types.Float32: 174 g.imps["math"] = 1 175 g.printf( 176 "binary.LittleEndian.PutUint32(buf[:4], math.Float32bits(%s))\n", 177 n, 178 ) 179 g.printf("data = append(data, buf[:4]...)\n") 180 181 case types.Float64: 182 g.imps["math"] = 1 183 g.printf( 184 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(%s))\n", 185 n, 186 ) 187 g.printf("data = append(data, buf[:8]...)\n") 188 189 case types.Complex64: 190 g.imps["math"] = 1 191 g.printf( 192 "binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(real(%s)))\n", 193 n, 194 ) 195 g.printf("data = append(data, buf[:4]...)\n") 196 g.printf( 197 "binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(imag(%s)))\n", 198 n, 199 ) 200 g.printf("data = append(data, buf[:4]...)\n") 201 202 case types.Complex128: 203 g.imps["math"] = 1 204 g.printf( 205 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(real(%s)))\n", 206 n, 207 ) 208 g.printf("data = append(data, buf[:8]...)\n") 209 g.printf( 210 "binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(imag(%s)))\n", 211 n, 212 ) 213 g.printf("data = append(data, buf[:8]...)\n") 214 215 case types.String: 216 g.printf( 217 "binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n", 218 n, 219 ) 220 g.printf("data = append(data, buf[:8]...)\n") 221 g.printf("data = append(data, []byte(%s)...)\n", n) 222 223 default: 224 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 225 } 226 227 case *types.Struct: 228 switch t.(type) { 229 case *types.Named: 230 g.printf("{\nsub, err := %s.MarshalBinary()\n", n) 231 g.printf("if err != nil {\nreturn nil, err\n}\n") 232 g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n") 233 g.printf("data = append(data, buf[:8]...)\n") 234 g.printf("data = append(data, sub...)\n") 235 g.printf("}\n") 236 default: 237 // un-named 238 for elem := range ut.Fields() { 239 g.genMarshalType(elem.Type(), n+"."+elem.Name()) 240 } 241 } 242 243 case *types.Array: 244 if isByteType(ut.Elem()) { 245 g.printf("data = append(data, %s[:]...)\n", n) 246 } else { 247 g.printf("for i := range %s {\n", n) 248 if _, ok := ut.Elem().(*types.Pointer); ok { 249 g.printf("o := %s[i]\n", n) 250 } else { 251 g.printf("o := &%s[i]\n", n) 252 } 253 g.genMarshalType(ut.Elem(), "o") 254 g.printf("}\n") 255 } 256 257 case *types.Slice: 258 g.printf( 259 "binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n", 260 n, 261 ) 262 g.printf("data = append(data, buf[:8]...)\n") 263 if isByteType(ut.Elem()) { 264 g.printf("data = append(data, %s...)\n", n) 265 } else { 266 g.printf("for i := range %s {\n", n) 267 if _, ok := ut.Elem().(*types.Pointer); ok { 268 g.printf("o := %s[i]\n", n) 269 } else { 270 g.printf("o := &%s[i]\n", n) 271 } 272 g.genMarshalType(ut.Elem(), "o") 273 g.printf("}\n") 274 } 275 276 case *types.Pointer: 277 g.printf("{\n") 278 g.printf("v := *%s\n", n) 279 g.genMarshalType(ut.Elem(), "v") 280 g.printf("}\n") 281 282 case *types.Interface: 283 log.Fatalf("marshal interface not supported (type=%v)\n", t) 284 285 default: 286 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 287 } 288 } 289 290 func (g *Generator) genUnmarshal(t types.Type, typeName string) { 291 g.printf(`// UnmarshalBinary implements encoding.BinaryUnmarshaler 292 func (o *%[1]s) UnmarshalBinary(data []byte) (err error) { 293 `, 294 typeName, 295 ) 296 297 typ := t.Underlying().(*types.Struct) 298 for ft := range typ.Fields() { 299 g.genUnmarshalType(ft.Type(), "o."+ft.Name()) 300 } 301 302 g.printf("_ = data\n") 303 g.printf("return err\n}\n\n") 304 } 305 306 func (g *Generator) genUnmarshalType(t types.Type, n string) { 307 if types.Implements(t, binUn) || types.Implements(types.NewPointer(t), binUn) { 308 g.printf("{\n") 309 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 310 g.printf("data = data[8:]\n") 311 g.printf("err = %s.UnmarshalBinary(data[:n])\n", n) 312 g.printf("if err != nil {\nreturn err\n}\n") 313 g.printf("data = data[n:]\n") 314 g.printf("}\n") 315 return 316 } 317 318 tn := types.TypeString(t, types.RelativeTo(g.pkg)) 319 ut := t.Underlying() 320 switch ut := ut.(type) { 321 case *types.Basic: 322 switch kind := ut.Kind(); kind { 323 324 case types.Bool: 325 g.printf("switch data[i] {\ncase 0:\n%s = false\n", n) 326 g.printf("default:\n%s = true\n}\n", n) 327 g.printf("data = data[1:]\n") 328 329 case types.Uint: 330 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 331 g.printf("data = data[8:]\n") 332 333 case types.Uint8: 334 g.printf("%s = %s(data[0])\n", n, tn) 335 g.printf("data = data[1:]\n") 336 337 case types.Uint16: 338 g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn) 339 g.printf("data = data[2:]\n") 340 341 case types.Uint32: 342 g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn) 343 g.printf("data = data[4:]\n") 344 345 case types.Uint64: 346 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 347 g.printf("data = data[8:]\n") 348 349 case types.Int: 350 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 351 g.printf("data = data[8:]\n") 352 353 case types.Int8: 354 g.printf("%s = %s(data[0])\n", n, tn) 355 g.printf("data = data[1:]\n") 356 357 case types.Int16: 358 g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn) 359 g.printf("data = data[2:]\n") 360 361 case types.Int32: 362 g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn) 363 g.printf("data = data[4:]\n") 364 365 case types.Int64: 366 g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn) 367 g.printf("data = data[8:]\n") 368 369 case types.Float32: 370 g.imps["math"] = 1 371 g.printf("%s = %s(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])))\n", n, tn) 372 g.printf("data = data[4:]\n") 373 374 case types.Float64: 375 g.imps["math"] = 1 376 g.printf("%s = %s(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])))\n", n, tn) 377 g.printf("data = data[8:]\n") 378 379 case types.Complex64: 380 g.imps["math"] = 1 381 g.printf("%s = %s(complex(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])), math.Float32frombits(binary.LittleEndian.Uint32(data[4:8]))))\n", n, tn) 382 g.printf("data = data[8:]\n") 383 384 case types.Complex128: 385 g.imps["math"] = 1 386 g.printf("%s = %s(complex(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])), math.Float64frombits(binary.LittleEndian.Uint64(data[8:16]))))\n", n, tn) 387 g.printf("data = data[16:]\n") 388 389 case types.String: 390 g.printf("{\n") 391 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 392 g.printf("data = data[8:]\n") 393 g.printf("%s = %s(data[:n])\n", n, tn) 394 g.printf("data = data[n:]\n") 395 g.printf("}\n") 396 397 default: 398 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 399 } 400 401 case *types.Struct: 402 switch t.(type) { 403 case *types.Named: 404 g.printf("{\n") 405 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 406 g.printf("data = data[8:]\n") 407 g.printf("err = %s.UnmarshalBinary(data[:n])\n", n) 408 g.printf("if err != nil {\nreturn err\n}\n") 409 g.printf("data = data[n:]\n") 410 g.printf("}\n") 411 default: 412 // un-named. 413 for elem := range ut.Fields() { 414 g.genUnmarshalType(elem.Type(), n+"."+elem.Name()) 415 } 416 } 417 418 case *types.Array: 419 if isByteType(ut.Elem()) { 420 g.printf("copy(%s[:], data[:n])\n", n) 421 g.printf("data = data[n:]\n") 422 } else { 423 g.printf("for i := range %s {\n", n) 424 nn := n + "[i]" 425 if pt, ok := ut.Elem().(*types.Pointer); ok { 426 g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg)) 427 nn = "oi" 428 } 429 if _, ok := ut.Elem().Underlying().(*types.Struct); ok { 430 g.printf("oi := &%s[i]\n", n) 431 nn = "oi" 432 } 433 g.genUnmarshalType(ut.Elem(), nn) 434 if _, ok := ut.Elem().(*types.Pointer); ok { 435 g.printf("%s[i] = oi\n", n) 436 } 437 g.printf("}\n") 438 } 439 440 case *types.Slice: 441 g.printf("{\n") 442 g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n") 443 g.printf("%[1]s = make([]%[2]s, n)\n", n, qualTypeName(ut.Elem(), g.pkg)) 444 g.printf("data = data[8:]\n") 445 if isByteType(ut.Elem()) { 446 g.printf("%[1]s = append(%[1]s, data[:n]...)\n", n) 447 g.printf("data = data[n:]\n") 448 } else { 449 g.printf("for i := range %s {\n", n) 450 nn := n + "[i]" 451 if pt, ok := ut.Elem().(*types.Pointer); ok { 452 g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg)) 453 nn = "oi" 454 } 455 if _, ok := ut.Elem().Underlying().(*types.Struct); ok { 456 g.printf("oi := &%s[i]\n", n) 457 nn = "oi" 458 } 459 g.genUnmarshalType(ut.Elem(), nn) 460 if _, ok := ut.Elem().(*types.Pointer); ok { 461 g.printf("%s[i] = oi\n", n) 462 } 463 g.printf("}\n") 464 } 465 g.printf("}\n") 466 467 case *types.Pointer: 468 g.printf("{\n") 469 elt := ut.Elem() 470 g.printf("var v %s\n", qualTypeName(elt, g.pkg)) 471 g.genUnmarshalType(elt, "v") 472 g.printf("%s = &v\n\n", n) 473 g.printf("}\n") 474 475 case *types.Interface: 476 log.Fatalf("marshal interface not supported (type=%v)\n", t) 477 478 default: 479 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 480 } 481 482 } 483 484 func isByteType(t types.Type) bool { 485 b, ok := t.Underlying().(*types.Basic) 486 if !ok { 487 return false 488 } 489 return b.Kind() == types.Byte 490 } 491 492 func qualTypeName(t types.Type, pkg *types.Package) string { 493 n := types.TypeString(t, types.RelativeTo(pkg)) 494 i := strings.LastIndex(n, "/") 495 if i < 0 { 496 return n 497 } 498 return string(n[i+1:]) 499 } 500 501 func (g *Generator) Format() ([]byte, error) { 502 buf := new(bytes.Buffer) 503 504 // See standard at https://golang.org/s/generatedcode 505 buf.WriteString(fmt.Sprintf(`// Code generated by %[1]s; DO NOT EDIT. 506 507 package %[2]s 508 509 import ( 510 "encoding/binary" 511 `, 512 "brio-gen", 513 g.pkg.Name(), 514 )) 515 516 for k := range g.imps { 517 fmt.Fprintf(buf, "%q\n", k) 518 } 519 fmt.Fprintf(buf, ")\n\n") 520 521 buf.Write(g.buf.Bytes()) 522 523 src, err := format.Source(buf.Bytes()) 524 if err != nil { 525 log.Printf("=== error ===\n%s\n", buf.Bytes()) 526 } 527 return src, err 528 } 529 530 func importPkg(p string) (*types.Package, error) { 531 cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps} 532 pkgs, err := packages.Load(cfg, p) 533 if err != nil { 534 return nil, fmt.Errorf("could not load package %q: %w", p, err) 535 } 536 537 return pkgs[0].Types, nil 538 } 539 540 func init() { 541 pkg, err := importPkg("encoding") 542 if err != nil { 543 log.Fatalf("error finding package \"encoding\": %v\n", err) 544 } 545 546 o := pkg.Scope().Lookup("BinaryMarshaler") 547 if o == nil { 548 log.Fatalf("could not find interface encoding.BinaryMarshaler\n") 549 } 550 binMa = o.(*types.TypeName).Type().Underlying().(*types.Interface) 551 552 o = pkg.Scope().Lookup("BinaryUnmarshaler") 553 if o == nil { 554 log.Fatalf("could not find interface encoding.BinaryUnmarshaler\n") 555 } 556 binUn = o.(*types.TypeName).Type().Underlying().(*types.Interface) 557 }