go-hep.org/x/hep@v0.38.1/brio/cmd/brio-gen/internal/gen/gen.go (about)

     1  // Copyright ©2016 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gen // import "go-hep.org/x/hep/brio/cmd/brio-gen/internal/gen"
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/format"
    11  	"go/types"
    12  	"log"
    13  	"strings"
    14  
    15  	"golang.org/x/tools/go/packages"
    16  )
    17  
    18  var (
    19  	binMa *types.Interface // encoding.BinaryMarshaler
    20  	binUn *types.Interface // encoding.BinaryUnmarshaler
    21  )
    22  
    23  // Generator holds the state of the generation.
    24  type Generator struct {
    25  	buf *bytes.Buffer
    26  	pkg *types.Package
    27  
    28  	// set of imported packages.
    29  	// usually: "encoding/binary", "math"
    30  	imps map[string]int
    31  
    32  	Verbose bool // enable verbose mode
    33  }
    34  
    35  // NewGenerator returns a new code generator for package p,
    36  // where p is the package's import path.
    37  func NewGenerator(p string) (*Generator, error) {
    38  	pkg, err := importPkg(p)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	return &Generator{
    44  		buf:  new(bytes.Buffer),
    45  		pkg:  pkg,
    46  		imps: map[string]int{"encoding/binary": 1},
    47  	}, nil
    48  }
    49  
    50  func (g *Generator) printf(format string, args ...any) {
    51  	fmt.Fprintf(g.buf, format, args...)
    52  }
    53  
    54  func (g *Generator) Generate(typeName string) {
    55  	scope := g.pkg.Scope()
    56  	obj := scope.Lookup(typeName)
    57  	if obj == nil {
    58  		log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name())
    59  	}
    60  
    61  	tn, ok := obj.(*types.TypeName)
    62  	if !ok {
    63  		log.Fatalf("%q is not a type (%v)\n", typeName, obj)
    64  	}
    65  
    66  	typ, ok := tn.Type().Underlying().(*types.Struct)
    67  	if !ok {
    68  		log.Fatalf("%q is not a named struct (%v)\n", typeName, tn)
    69  	}
    70  	if g.Verbose {
    71  		log.Printf("typ: %+v\n", typ)
    72  	}
    73  
    74  	g.genMarshal(typ, typeName)
    75  	g.genUnmarshal(typ, typeName)
    76  }
    77  
    78  func (g *Generator) genMarshal(t types.Type, typeName string) {
    79  	g.printf(`// MarshalBinary implements encoding.BinaryMarshaler
    80  func (o *%[1]s) MarshalBinary() (data []byte, err error) {
    81  	var buf [8]byte
    82  `,
    83  		typeName,
    84  	)
    85  
    86  	typ := t.Underlying().(*types.Struct)
    87  	for i := range typ.NumFields() {
    88  		ft := typ.Field(i)
    89  		g.genMarshalType(ft.Type(), "o."+ft.Name())
    90  	}
    91  
    92  	g.printf("return data, err\n}\n\n")
    93  }
    94  
    95  func (g *Generator) genMarshalType(t types.Type, n string) {
    96  	if types.Implements(t, binMa) || types.Implements(types.NewPointer(t), binMa) {
    97  		g.printf("{\nsub, err := %s.MarshalBinary()\n", n)
    98  		g.printf("if err != nil {\nreturn nil, err\n}\n")
    99  		g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n")
   100  		g.printf("data = append(data, buf[:8]...)\n")
   101  		g.printf("data = append(data, sub...)\n")
   102  		g.printf("}\n")
   103  		return
   104  	}
   105  
   106  	ut := t.Underlying()
   107  	switch ut := ut.(type) {
   108  	case *types.Basic:
   109  		switch kind := ut.Kind(); kind {
   110  
   111  		case types.Bool:
   112  			g.printf("switch %s {\ncase false:\n data = append(data, uint8(0))\n", n)
   113  			g.printf("default:\ndata = append(data, uint8(1))\n}\n")
   114  
   115  		case types.Uint:
   116  			g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", n)
   117  			g.printf("data = append(data, buf[:8]...)\n")
   118  
   119  		case types.Uint8:
   120  			g.printf("data = append(data, byte(%s))\n", n)
   121  
   122  		case types.Uint16:
   123  			g.printf(
   124  				"binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n",
   125  				n,
   126  			)
   127  			g.printf("data = append(data, buf[:2]...)\n")
   128  
   129  		case types.Uint32:
   130  			g.printf(
   131  				"binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n",
   132  				n,
   133  			)
   134  			g.printf("data = append(data, buf[:4]...)\n")
   135  
   136  		case types.Uint64:
   137  			g.printf(
   138  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   139  				n,
   140  			)
   141  			g.printf("data = append(data, buf[:8]...)\n")
   142  
   143  		case types.Int:
   144  			g.printf(
   145  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   146  				n,
   147  			)
   148  			g.printf("data = append(data, buf[:8]...)\n")
   149  
   150  		case types.Int8:
   151  			g.printf("data = append(data, byte(%s))\n", n)
   152  
   153  		case types.Int16:
   154  			g.printf(
   155  				"binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n",
   156  				n,
   157  			)
   158  			g.printf("data = append(data, buf[:2]...)\n")
   159  
   160  		case types.Int32:
   161  			g.printf(
   162  				"binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n",
   163  				n,
   164  			)
   165  			g.printf("data = append(data, buf[:4]...)\n")
   166  
   167  		case types.Int64:
   168  			g.printf(
   169  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   170  				n,
   171  			)
   172  			g.printf("data = append(data, buf[:8]...)\n")
   173  
   174  		case types.Float32:
   175  			g.imps["math"] = 1
   176  			g.printf(
   177  				"binary.LittleEndian.PutUint32(buf[:4], math.Float32bits(%s))\n",
   178  				n,
   179  			)
   180  			g.printf("data = append(data, buf[:4]...)\n")
   181  
   182  		case types.Float64:
   183  			g.imps["math"] = 1
   184  			g.printf(
   185  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(%s))\n",
   186  				n,
   187  			)
   188  			g.printf("data = append(data, buf[:8]...)\n")
   189  
   190  		case types.Complex64:
   191  			g.imps["math"] = 1
   192  			g.printf(
   193  				"binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(real(%s)))\n",
   194  				n,
   195  			)
   196  			g.printf("data = append(data, buf[:4]...)\n")
   197  			g.printf(
   198  				"binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(imag(%s)))\n",
   199  				n,
   200  			)
   201  			g.printf("data = append(data, buf[:4]...)\n")
   202  
   203  		case types.Complex128:
   204  			g.imps["math"] = 1
   205  			g.printf(
   206  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(real(%s)))\n",
   207  				n,
   208  			)
   209  			g.printf("data = append(data, buf[:8]...)\n")
   210  			g.printf(
   211  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(imag(%s)))\n",
   212  				n,
   213  			)
   214  			g.printf("data = append(data, buf[:8]...)\n")
   215  
   216  		case types.String:
   217  			g.printf(
   218  				"binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n",
   219  				n,
   220  			)
   221  			g.printf("data = append(data, buf[:8]...)\n")
   222  			g.printf("data = append(data, []byte(%s)...)\n", n)
   223  
   224  		default:
   225  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   226  		}
   227  
   228  	case *types.Struct:
   229  		switch t.(type) {
   230  		case *types.Named:
   231  			g.printf("{\nsub, err := %s.MarshalBinary()\n", n)
   232  			g.printf("if err != nil {\nreturn nil, err\n}\n")
   233  			g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n")
   234  			g.printf("data = append(data, buf[:8]...)\n")
   235  			g.printf("data = append(data, sub...)\n")
   236  			g.printf("}\n")
   237  		default:
   238  			// un-named
   239  			for i := range ut.NumFields() {
   240  				elem := ut.Field(i)
   241  				g.genMarshalType(elem.Type(), n+"."+elem.Name())
   242  			}
   243  		}
   244  
   245  	case *types.Array:
   246  		if isByteType(ut.Elem()) {
   247  			g.printf("data = append(data, %s[:]...)\n", n)
   248  		} else {
   249  			g.printf("for i := range %s {\n", n)
   250  			if _, ok := ut.Elem().(*types.Pointer); ok {
   251  				g.printf("o := %s[i]\n", n)
   252  			} else {
   253  				g.printf("o := &%s[i]\n", n)
   254  			}
   255  			g.genMarshalType(ut.Elem(), "o")
   256  			g.printf("}\n")
   257  		}
   258  
   259  	case *types.Slice:
   260  		g.printf(
   261  			"binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n",
   262  			n,
   263  		)
   264  		g.printf("data = append(data, buf[:8]...)\n")
   265  		if isByteType(ut.Elem()) {
   266  			g.printf("data = append(data, %s...)\n", n)
   267  		} else {
   268  			g.printf("for i := range %s {\n", n)
   269  			if _, ok := ut.Elem().(*types.Pointer); ok {
   270  				g.printf("o := %s[i]\n", n)
   271  			} else {
   272  				g.printf("o := &%s[i]\n", n)
   273  			}
   274  			g.genMarshalType(ut.Elem(), "o")
   275  			g.printf("}\n")
   276  		}
   277  
   278  	case *types.Pointer:
   279  		g.printf("{\n")
   280  		g.printf("v := *%s\n", n)
   281  		g.genMarshalType(ut.Elem(), "v")
   282  		g.printf("}\n")
   283  
   284  	case *types.Interface:
   285  		log.Fatalf("marshal interface not supported (type=%v)\n", t)
   286  
   287  	default:
   288  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   289  	}
   290  }
   291  
   292  func (g *Generator) genUnmarshal(t types.Type, typeName string) {
   293  	g.printf(`// UnmarshalBinary implements encoding.BinaryUnmarshaler
   294  func (o *%[1]s) UnmarshalBinary(data []byte) (err error) {
   295  `,
   296  		typeName,
   297  	)
   298  
   299  	typ := t.Underlying().(*types.Struct)
   300  	for i := range typ.NumFields() {
   301  		ft := typ.Field(i)
   302  		g.genUnmarshalType(ft.Type(), "o."+ft.Name())
   303  	}
   304  
   305  	g.printf("_ = data\n")
   306  	g.printf("return err\n}\n\n")
   307  }
   308  
   309  func (g *Generator) genUnmarshalType(t types.Type, n string) {
   310  	if types.Implements(t, binUn) || types.Implements(types.NewPointer(t), binUn) {
   311  		g.printf("{\n")
   312  		g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   313  		g.printf("data = data[8:]\n")
   314  		g.printf("err = %s.UnmarshalBinary(data[:n])\n", n)
   315  		g.printf("if err != nil {\nreturn err\n}\n")
   316  		g.printf("data = data[n:]\n")
   317  		g.printf("}\n")
   318  		return
   319  	}
   320  
   321  	tn := types.TypeString(t, types.RelativeTo(g.pkg))
   322  	ut := t.Underlying()
   323  	switch ut := ut.(type) {
   324  	case *types.Basic:
   325  		switch kind := ut.Kind(); kind {
   326  
   327  		case types.Bool:
   328  			g.printf("switch data[i] {\ncase 0:\n%s = false\n", n)
   329  			g.printf("default:\n%s = true\n}\n", n)
   330  			g.printf("data = data[1:]\n")
   331  
   332  		case types.Uint:
   333  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   334  			g.printf("data = data[8:]\n")
   335  
   336  		case types.Uint8:
   337  			g.printf("%s = %s(data[0])\n", n, tn)
   338  			g.printf("data = data[1:]\n")
   339  
   340  		case types.Uint16:
   341  			g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn)
   342  			g.printf("data = data[2:]\n")
   343  
   344  		case types.Uint32:
   345  			g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn)
   346  			g.printf("data = data[4:]\n")
   347  
   348  		case types.Uint64:
   349  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   350  			g.printf("data = data[8:]\n")
   351  
   352  		case types.Int:
   353  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   354  			g.printf("data = data[8:]\n")
   355  
   356  		case types.Int8:
   357  			g.printf("%s = %s(data[0])\n", n, tn)
   358  			g.printf("data = data[1:]\n")
   359  
   360  		case types.Int16:
   361  			g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn)
   362  			g.printf("data = data[2:]\n")
   363  
   364  		case types.Int32:
   365  			g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn)
   366  			g.printf("data = data[4:]\n")
   367  
   368  		case types.Int64:
   369  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   370  			g.printf("data = data[8:]\n")
   371  
   372  		case types.Float32:
   373  			g.imps["math"] = 1
   374  			g.printf("%s = %s(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])))\n", n, tn)
   375  			g.printf("data = data[4:]\n")
   376  
   377  		case types.Float64:
   378  			g.imps["math"] = 1
   379  			g.printf("%s = %s(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])))\n", n, tn)
   380  			g.printf("data = data[8:]\n")
   381  
   382  		case types.Complex64:
   383  			g.imps["math"] = 1
   384  			g.printf("%s = %s(complex(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])), math.Float32frombits(binary.LittleEndian.Uint32(data[4:8]))))\n", n, tn)
   385  			g.printf("data = data[8:]\n")
   386  
   387  		case types.Complex128:
   388  			g.imps["math"] = 1
   389  			g.printf("%s = %s(complex(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])), math.Float64frombits(binary.LittleEndian.Uint64(data[8:16]))))\n", n, tn)
   390  			g.printf("data = data[16:]\n")
   391  
   392  		case types.String:
   393  			g.printf("{\n")
   394  			g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   395  			g.printf("data = data[8:]\n")
   396  			g.printf("%s = %s(data[:n])\n", n, tn)
   397  			g.printf("data = data[n:]\n")
   398  			g.printf("}\n")
   399  
   400  		default:
   401  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   402  		}
   403  
   404  	case *types.Struct:
   405  		switch t.(type) {
   406  		case *types.Named:
   407  			g.printf("{\n")
   408  			g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   409  			g.printf("data = data[8:]\n")
   410  			g.printf("err = %s.UnmarshalBinary(data[:n])\n", n)
   411  			g.printf("if err != nil {\nreturn err\n}\n")
   412  			g.printf("data = data[n:]\n")
   413  			g.printf("}\n")
   414  		default:
   415  			// un-named.
   416  			for i := range ut.NumFields() {
   417  				elem := ut.Field(i)
   418  				g.genUnmarshalType(elem.Type(), n+"."+elem.Name())
   419  			}
   420  		}
   421  
   422  	case *types.Array:
   423  		if isByteType(ut.Elem()) {
   424  			g.printf("copy(%s[:], data[:n])\n", n)
   425  			g.printf("data = data[n:]\n")
   426  		} else {
   427  			g.printf("for i := range %s {\n", n)
   428  			nn := n + "[i]"
   429  			if pt, ok := ut.Elem().(*types.Pointer); ok {
   430  				g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg))
   431  				nn = "oi"
   432  			}
   433  			if _, ok := ut.Elem().Underlying().(*types.Struct); ok {
   434  				g.printf("oi := &%s[i]\n", n)
   435  				nn = "oi"
   436  			}
   437  			g.genUnmarshalType(ut.Elem(), nn)
   438  			if _, ok := ut.Elem().(*types.Pointer); ok {
   439  				g.printf("%s[i] = oi\n", n)
   440  			}
   441  			g.printf("}\n")
   442  		}
   443  
   444  	case *types.Slice:
   445  		g.printf("{\n")
   446  		g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   447  		g.printf("%[1]s = make([]%[2]s, n)\n", n, qualTypeName(ut.Elem(), g.pkg))
   448  		g.printf("data = data[8:]\n")
   449  		if isByteType(ut.Elem()) {
   450  			g.printf("%[1]s = append(%[1]s, data[:n]...)\n", n)
   451  			g.printf("data = data[n:]\n")
   452  		} else {
   453  			g.printf("for i := range %s {\n", n)
   454  			nn := n + "[i]"
   455  			if pt, ok := ut.Elem().(*types.Pointer); ok {
   456  				g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg))
   457  				nn = "oi"
   458  			}
   459  			if _, ok := ut.Elem().Underlying().(*types.Struct); ok {
   460  				g.printf("oi := &%s[i]\n", n)
   461  				nn = "oi"
   462  			}
   463  			g.genUnmarshalType(ut.Elem(), nn)
   464  			if _, ok := ut.Elem().(*types.Pointer); ok {
   465  				g.printf("%s[i] = oi\n", n)
   466  			}
   467  			g.printf("}\n")
   468  		}
   469  		g.printf("}\n")
   470  
   471  	case *types.Pointer:
   472  		g.printf("{\n")
   473  		elt := ut.Elem()
   474  		g.printf("var v %s\n", qualTypeName(elt, g.pkg))
   475  		g.genUnmarshalType(elt, "v")
   476  		g.printf("%s = &v\n\n", n)
   477  		g.printf("}\n")
   478  
   479  	case *types.Interface:
   480  		log.Fatalf("marshal interface not supported (type=%v)\n", t)
   481  
   482  	default:
   483  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   484  	}
   485  
   486  }
   487  
   488  func isByteType(t types.Type) bool {
   489  	b, ok := t.Underlying().(*types.Basic)
   490  	if !ok {
   491  		return false
   492  	}
   493  	return b.Kind() == types.Byte
   494  }
   495  
   496  func qualTypeName(t types.Type, pkg *types.Package) string {
   497  	n := types.TypeString(t, types.RelativeTo(pkg))
   498  	i := strings.LastIndex(n, "/")
   499  	if i < 0 {
   500  		return n
   501  	}
   502  	return string(n[i+1:])
   503  }
   504  
   505  func (g *Generator) Format() ([]byte, error) {
   506  	buf := new(bytes.Buffer)
   507  
   508  	// See standard at https://golang.org/s/generatedcode
   509  	buf.WriteString(fmt.Sprintf(`// Code generated by %[1]s; DO NOT EDIT.
   510  
   511  package %[2]s
   512  
   513  import (
   514  	"encoding/binary"
   515  `,
   516  		"brio-gen",
   517  		g.pkg.Name(),
   518  	))
   519  
   520  	for k := range g.imps {
   521  		fmt.Fprintf(buf, "%q\n", k)
   522  	}
   523  	fmt.Fprintf(buf, ")\n\n")
   524  
   525  	buf.Write(g.buf.Bytes())
   526  
   527  	src, err := format.Source(buf.Bytes())
   528  	if err != nil {
   529  		log.Printf("=== error ===\n%s\n", buf.Bytes())
   530  	}
   531  	return src, err
   532  }
   533  
   534  func importPkg(p string) (*types.Package, error) {
   535  	cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps}
   536  	pkgs, err := packages.Load(cfg, p)
   537  	if err != nil {
   538  		return nil, fmt.Errorf("could not load package %q: %w", p, err)
   539  	}
   540  
   541  	return pkgs[0].Types, nil
   542  }
   543  
   544  func init() {
   545  	pkg, err := importPkg("encoding")
   546  	if err != nil {
   547  		log.Fatalf("error finding package \"encoding\": %v\n", err)
   548  	}
   549  
   550  	o := pkg.Scope().Lookup("BinaryMarshaler")
   551  	if o == nil {
   552  		log.Fatalf("could not find interface encoding.BinaryMarshaler\n")
   553  	}
   554  	binMa = o.(*types.TypeName).Type().Underlying().(*types.Interface)
   555  
   556  	o = pkg.Scope().Lookup("BinaryUnmarshaler")
   557  	if o == nil {
   558  		log.Fatalf("could not find interface encoding.BinaryUnmarshaler\n")
   559  	}
   560  	binUn = o.(*types.TypeName).Type().Underlying().(*types.Interface)
   561  }