go-hep.org/x/hep@v0.40.0/brio/cmd/brio-gen/internal/gen/gen.go (about)

     1  // Copyright ©2016 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gen // import "go-hep.org/x/hep/brio/cmd/brio-gen/internal/gen"
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/format"
    11  	"go/types"
    12  	"log"
    13  	"strings"
    14  
    15  	"golang.org/x/tools/go/packages"
    16  )
    17  
    18  var (
    19  	binMa *types.Interface // encoding.BinaryMarshaler
    20  	binUn *types.Interface // encoding.BinaryUnmarshaler
    21  )
    22  
    23  // Generator holds the state of the generation.
    24  type Generator struct {
    25  	buf *bytes.Buffer
    26  	pkg *types.Package
    27  
    28  	// set of imported packages.
    29  	// usually: "encoding/binary", "math"
    30  	imps map[string]int
    31  
    32  	Verbose bool // enable verbose mode
    33  }
    34  
    35  // NewGenerator returns a new code generator for package p,
    36  // where p is the package's import path.
    37  func NewGenerator(p string) (*Generator, error) {
    38  	pkg, err := importPkg(p)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	return &Generator{
    44  		buf:  new(bytes.Buffer),
    45  		pkg:  pkg,
    46  		imps: map[string]int{"encoding/binary": 1},
    47  	}, nil
    48  }
    49  
    50  func (g *Generator) printf(format string, args ...any) {
    51  	fmt.Fprintf(g.buf, format, args...)
    52  }
    53  
    54  func (g *Generator) Generate(typeName string) {
    55  	scope := g.pkg.Scope()
    56  	obj := scope.Lookup(typeName)
    57  	if obj == nil {
    58  		log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name())
    59  	}
    60  
    61  	tn, ok := obj.(*types.TypeName)
    62  	if !ok {
    63  		log.Fatalf("%q is not a type (%v)\n", typeName, obj)
    64  	}
    65  
    66  	typ, ok := tn.Type().Underlying().(*types.Struct)
    67  	if !ok {
    68  		log.Fatalf("%q is not a named struct (%v)\n", typeName, tn)
    69  	}
    70  	if g.Verbose {
    71  		log.Printf("typ: %+v\n", typ)
    72  	}
    73  
    74  	g.genMarshal(typ, typeName)
    75  	g.genUnmarshal(typ, typeName)
    76  }
    77  
    78  func (g *Generator) genMarshal(t types.Type, typeName string) {
    79  	g.printf(`// MarshalBinary implements encoding.BinaryMarshaler
    80  func (o *%[1]s) MarshalBinary() (data []byte, err error) {
    81  	var buf [8]byte
    82  `,
    83  		typeName,
    84  	)
    85  
    86  	typ := t.Underlying().(*types.Struct)
    87  	for ft := range typ.Fields() {
    88  		g.genMarshalType(ft.Type(), "o."+ft.Name())
    89  	}
    90  
    91  	g.printf("return data, err\n}\n\n")
    92  }
    93  
    94  func (g *Generator) genMarshalType(t types.Type, n string) {
    95  	if types.Implements(t, binMa) || types.Implements(types.NewPointer(t), binMa) {
    96  		g.printf("{\nsub, err := %s.MarshalBinary()\n", n)
    97  		g.printf("if err != nil {\nreturn nil, err\n}\n")
    98  		g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n")
    99  		g.printf("data = append(data, buf[:8]...)\n")
   100  		g.printf("data = append(data, sub...)\n")
   101  		g.printf("}\n")
   102  		return
   103  	}
   104  
   105  	ut := t.Underlying()
   106  	switch ut := ut.(type) {
   107  	case *types.Basic:
   108  		switch kind := ut.Kind(); kind {
   109  
   110  		case types.Bool:
   111  			g.printf("switch %s {\ncase false:\n data = append(data, uint8(0))\n", n)
   112  			g.printf("default:\ndata = append(data, uint8(1))\n}\n")
   113  
   114  		case types.Uint:
   115  			g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n", n)
   116  			g.printf("data = append(data, buf[:8]...)\n")
   117  
   118  		case types.Uint8:
   119  			g.printf("data = append(data, byte(%s))\n", n)
   120  
   121  		case types.Uint16:
   122  			g.printf(
   123  				"binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n",
   124  				n,
   125  			)
   126  			g.printf("data = append(data, buf[:2]...)\n")
   127  
   128  		case types.Uint32:
   129  			g.printf(
   130  				"binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n",
   131  				n,
   132  			)
   133  			g.printf("data = append(data, buf[:4]...)\n")
   134  
   135  		case types.Uint64:
   136  			g.printf(
   137  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   138  				n,
   139  			)
   140  			g.printf("data = append(data, buf[:8]...)\n")
   141  
   142  		case types.Int:
   143  			g.printf(
   144  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   145  				n,
   146  			)
   147  			g.printf("data = append(data, buf[:8]...)\n")
   148  
   149  		case types.Int8:
   150  			g.printf("data = append(data, byte(%s))\n", n)
   151  
   152  		case types.Int16:
   153  			g.printf(
   154  				"binary.LittleEndian.PutUint16(buf[:2], uint16(%s))\n",
   155  				n,
   156  			)
   157  			g.printf("data = append(data, buf[:2]...)\n")
   158  
   159  		case types.Int32:
   160  			g.printf(
   161  				"binary.LittleEndian.PutUint32(buf[:4], uint32(%s))\n",
   162  				n,
   163  			)
   164  			g.printf("data = append(data, buf[:4]...)\n")
   165  
   166  		case types.Int64:
   167  			g.printf(
   168  				"binary.LittleEndian.PutUint64(buf[:8], uint64(%s))\n",
   169  				n,
   170  			)
   171  			g.printf("data = append(data, buf[:8]...)\n")
   172  
   173  		case types.Float32:
   174  			g.imps["math"] = 1
   175  			g.printf(
   176  				"binary.LittleEndian.PutUint32(buf[:4], math.Float32bits(%s))\n",
   177  				n,
   178  			)
   179  			g.printf("data = append(data, buf[:4]...)\n")
   180  
   181  		case types.Float64:
   182  			g.imps["math"] = 1
   183  			g.printf(
   184  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(%s))\n",
   185  				n,
   186  			)
   187  			g.printf("data = append(data, buf[:8]...)\n")
   188  
   189  		case types.Complex64:
   190  			g.imps["math"] = 1
   191  			g.printf(
   192  				"binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(real(%s)))\n",
   193  				n,
   194  			)
   195  			g.printf("data = append(data, buf[:4]...)\n")
   196  			g.printf(
   197  				"binary.LittleEndian.PutUint64(buf[:4], math.Float32bits(imag(%s)))\n",
   198  				n,
   199  			)
   200  			g.printf("data = append(data, buf[:4]...)\n")
   201  
   202  		case types.Complex128:
   203  			g.imps["math"] = 1
   204  			g.printf(
   205  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(real(%s)))\n",
   206  				n,
   207  			)
   208  			g.printf("data = append(data, buf[:8]...)\n")
   209  			g.printf(
   210  				"binary.LittleEndian.PutUint64(buf[:8], math.Float64bits(imag(%s)))\n",
   211  				n,
   212  			)
   213  			g.printf("data = append(data, buf[:8]...)\n")
   214  
   215  		case types.String:
   216  			g.printf(
   217  				"binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n",
   218  				n,
   219  			)
   220  			g.printf("data = append(data, buf[:8]...)\n")
   221  			g.printf("data = append(data, []byte(%s)...)\n", n)
   222  
   223  		default:
   224  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   225  		}
   226  
   227  	case *types.Struct:
   228  		switch t.(type) {
   229  		case *types.Named:
   230  			g.printf("{\nsub, err := %s.MarshalBinary()\n", n)
   231  			g.printf("if err != nil {\nreturn nil, err\n}\n")
   232  			g.printf("binary.LittleEndian.PutUint64(buf[:8], uint64(len(sub)))\n")
   233  			g.printf("data = append(data, buf[:8]...)\n")
   234  			g.printf("data = append(data, sub...)\n")
   235  			g.printf("}\n")
   236  		default:
   237  			// un-named
   238  			for elem := range ut.Fields() {
   239  				g.genMarshalType(elem.Type(), n+"."+elem.Name())
   240  			}
   241  		}
   242  
   243  	case *types.Array:
   244  		if isByteType(ut.Elem()) {
   245  			g.printf("data = append(data, %s[:]...)\n", n)
   246  		} else {
   247  			g.printf("for i := range %s {\n", n)
   248  			if _, ok := ut.Elem().(*types.Pointer); ok {
   249  				g.printf("o := %s[i]\n", n)
   250  			} else {
   251  				g.printf("o := &%s[i]\n", n)
   252  			}
   253  			g.genMarshalType(ut.Elem(), "o")
   254  			g.printf("}\n")
   255  		}
   256  
   257  	case *types.Slice:
   258  		g.printf(
   259  			"binary.LittleEndian.PutUint64(buf[:8], uint64(len(%s)))\n",
   260  			n,
   261  		)
   262  		g.printf("data = append(data, buf[:8]...)\n")
   263  		if isByteType(ut.Elem()) {
   264  			g.printf("data = append(data, %s...)\n", n)
   265  		} else {
   266  			g.printf("for i := range %s {\n", n)
   267  			if _, ok := ut.Elem().(*types.Pointer); ok {
   268  				g.printf("o := %s[i]\n", n)
   269  			} else {
   270  				g.printf("o := &%s[i]\n", n)
   271  			}
   272  			g.genMarshalType(ut.Elem(), "o")
   273  			g.printf("}\n")
   274  		}
   275  
   276  	case *types.Pointer:
   277  		g.printf("{\n")
   278  		g.printf("v := *%s\n", n)
   279  		g.genMarshalType(ut.Elem(), "v")
   280  		g.printf("}\n")
   281  
   282  	case *types.Interface:
   283  		log.Fatalf("marshal interface not supported (type=%v)\n", t)
   284  
   285  	default:
   286  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   287  	}
   288  }
   289  
   290  func (g *Generator) genUnmarshal(t types.Type, typeName string) {
   291  	g.printf(`// UnmarshalBinary implements encoding.BinaryUnmarshaler
   292  func (o *%[1]s) UnmarshalBinary(data []byte) (err error) {
   293  `,
   294  		typeName,
   295  	)
   296  
   297  	typ := t.Underlying().(*types.Struct)
   298  	for ft := range typ.Fields() {
   299  		g.genUnmarshalType(ft.Type(), "o."+ft.Name())
   300  	}
   301  
   302  	g.printf("_ = data\n")
   303  	g.printf("return err\n}\n\n")
   304  }
   305  
   306  func (g *Generator) genUnmarshalType(t types.Type, n string) {
   307  	if types.Implements(t, binUn) || types.Implements(types.NewPointer(t), binUn) {
   308  		g.printf("{\n")
   309  		g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   310  		g.printf("data = data[8:]\n")
   311  		g.printf("err = %s.UnmarshalBinary(data[:n])\n", n)
   312  		g.printf("if err != nil {\nreturn err\n}\n")
   313  		g.printf("data = data[n:]\n")
   314  		g.printf("}\n")
   315  		return
   316  	}
   317  
   318  	tn := types.TypeString(t, types.RelativeTo(g.pkg))
   319  	ut := t.Underlying()
   320  	switch ut := ut.(type) {
   321  	case *types.Basic:
   322  		switch kind := ut.Kind(); kind {
   323  
   324  		case types.Bool:
   325  			g.printf("switch data[i] {\ncase 0:\n%s = false\n", n)
   326  			g.printf("default:\n%s = true\n}\n", n)
   327  			g.printf("data = data[1:]\n")
   328  
   329  		case types.Uint:
   330  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   331  			g.printf("data = data[8:]\n")
   332  
   333  		case types.Uint8:
   334  			g.printf("%s = %s(data[0])\n", n, tn)
   335  			g.printf("data = data[1:]\n")
   336  
   337  		case types.Uint16:
   338  			g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn)
   339  			g.printf("data = data[2:]\n")
   340  
   341  		case types.Uint32:
   342  			g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn)
   343  			g.printf("data = data[4:]\n")
   344  
   345  		case types.Uint64:
   346  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   347  			g.printf("data = data[8:]\n")
   348  
   349  		case types.Int:
   350  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   351  			g.printf("data = data[8:]\n")
   352  
   353  		case types.Int8:
   354  			g.printf("%s = %s(data[0])\n", n, tn)
   355  			g.printf("data = data[1:]\n")
   356  
   357  		case types.Int16:
   358  			g.printf("%s = %s(binary.LittleEndian.Uint16(data[:2]))\n", n, tn)
   359  			g.printf("data = data[2:]\n")
   360  
   361  		case types.Int32:
   362  			g.printf("%s = %s(binary.LittleEndian.Uint32(data[:4]))\n", n, tn)
   363  			g.printf("data = data[4:]\n")
   364  
   365  		case types.Int64:
   366  			g.printf("%s = %s(binary.LittleEndian.Uint64(data[:8]))\n", n, tn)
   367  			g.printf("data = data[8:]\n")
   368  
   369  		case types.Float32:
   370  			g.imps["math"] = 1
   371  			g.printf("%s = %s(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])))\n", n, tn)
   372  			g.printf("data = data[4:]\n")
   373  
   374  		case types.Float64:
   375  			g.imps["math"] = 1
   376  			g.printf("%s = %s(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])))\n", n, tn)
   377  			g.printf("data = data[8:]\n")
   378  
   379  		case types.Complex64:
   380  			g.imps["math"] = 1
   381  			g.printf("%s = %s(complex(math.Float32frombits(binary.LittleEndian.Uint32(data[:4])), math.Float32frombits(binary.LittleEndian.Uint32(data[4:8]))))\n", n, tn)
   382  			g.printf("data = data[8:]\n")
   383  
   384  		case types.Complex128:
   385  			g.imps["math"] = 1
   386  			g.printf("%s = %s(complex(math.Float64frombits(binary.LittleEndian.Uint64(data[:8])), math.Float64frombits(binary.LittleEndian.Uint64(data[8:16]))))\n", n, tn)
   387  			g.printf("data = data[16:]\n")
   388  
   389  		case types.String:
   390  			g.printf("{\n")
   391  			g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   392  			g.printf("data = data[8:]\n")
   393  			g.printf("%s = %s(data[:n])\n", n, tn)
   394  			g.printf("data = data[n:]\n")
   395  			g.printf("}\n")
   396  
   397  		default:
   398  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   399  		}
   400  
   401  	case *types.Struct:
   402  		switch t.(type) {
   403  		case *types.Named:
   404  			g.printf("{\n")
   405  			g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   406  			g.printf("data = data[8:]\n")
   407  			g.printf("err = %s.UnmarshalBinary(data[:n])\n", n)
   408  			g.printf("if err != nil {\nreturn err\n}\n")
   409  			g.printf("data = data[n:]\n")
   410  			g.printf("}\n")
   411  		default:
   412  			// un-named.
   413  			for elem := range ut.Fields() {
   414  				g.genUnmarshalType(elem.Type(), n+"."+elem.Name())
   415  			}
   416  		}
   417  
   418  	case *types.Array:
   419  		if isByteType(ut.Elem()) {
   420  			g.printf("copy(%s[:], data[:n])\n", n)
   421  			g.printf("data = data[n:]\n")
   422  		} else {
   423  			g.printf("for i := range %s {\n", n)
   424  			nn := n + "[i]"
   425  			if pt, ok := ut.Elem().(*types.Pointer); ok {
   426  				g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg))
   427  				nn = "oi"
   428  			}
   429  			if _, ok := ut.Elem().Underlying().(*types.Struct); ok {
   430  				g.printf("oi := &%s[i]\n", n)
   431  				nn = "oi"
   432  			}
   433  			g.genUnmarshalType(ut.Elem(), nn)
   434  			if _, ok := ut.Elem().(*types.Pointer); ok {
   435  				g.printf("%s[i] = oi\n", n)
   436  			}
   437  			g.printf("}\n")
   438  		}
   439  
   440  	case *types.Slice:
   441  		g.printf("{\n")
   442  		g.printf("n := int(binary.LittleEndian.Uint64(data[:8]))\n")
   443  		g.printf("%[1]s = make([]%[2]s, n)\n", n, qualTypeName(ut.Elem(), g.pkg))
   444  		g.printf("data = data[8:]\n")
   445  		if isByteType(ut.Elem()) {
   446  			g.printf("%[1]s = append(%[1]s, data[:n]...)\n", n)
   447  			g.printf("data = data[n:]\n")
   448  		} else {
   449  			g.printf("for i := range %s {\n", n)
   450  			nn := n + "[i]"
   451  			if pt, ok := ut.Elem().(*types.Pointer); ok {
   452  				g.printf("var oi %s\n", qualTypeName(pt.Elem(), g.pkg))
   453  				nn = "oi"
   454  			}
   455  			if _, ok := ut.Elem().Underlying().(*types.Struct); ok {
   456  				g.printf("oi := &%s[i]\n", n)
   457  				nn = "oi"
   458  			}
   459  			g.genUnmarshalType(ut.Elem(), nn)
   460  			if _, ok := ut.Elem().(*types.Pointer); ok {
   461  				g.printf("%s[i] = oi\n", n)
   462  			}
   463  			g.printf("}\n")
   464  		}
   465  		g.printf("}\n")
   466  
   467  	case *types.Pointer:
   468  		g.printf("{\n")
   469  		elt := ut.Elem()
   470  		g.printf("var v %s\n", qualTypeName(elt, g.pkg))
   471  		g.genUnmarshalType(elt, "v")
   472  		g.printf("%s = &v\n\n", n)
   473  		g.printf("}\n")
   474  
   475  	case *types.Interface:
   476  		log.Fatalf("marshal interface not supported (type=%v)\n", t)
   477  
   478  	default:
   479  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   480  	}
   481  
   482  }
   483  
   484  func isByteType(t types.Type) bool {
   485  	b, ok := t.Underlying().(*types.Basic)
   486  	if !ok {
   487  		return false
   488  	}
   489  	return b.Kind() == types.Byte
   490  }
   491  
   492  func qualTypeName(t types.Type, pkg *types.Package) string {
   493  	n := types.TypeString(t, types.RelativeTo(pkg))
   494  	i := strings.LastIndex(n, "/")
   495  	if i < 0 {
   496  		return n
   497  	}
   498  	return string(n[i+1:])
   499  }
   500  
   501  func (g *Generator) Format() ([]byte, error) {
   502  	buf := new(bytes.Buffer)
   503  
   504  	// See standard at https://golang.org/s/generatedcode
   505  	buf.WriteString(fmt.Sprintf(`// Code generated by %[1]s; DO NOT EDIT.
   506  
   507  package %[2]s
   508  
   509  import (
   510  	"encoding/binary"
   511  `,
   512  		"brio-gen",
   513  		g.pkg.Name(),
   514  	))
   515  
   516  	for k := range g.imps {
   517  		fmt.Fprintf(buf, "%q\n", k)
   518  	}
   519  	fmt.Fprintf(buf, ")\n\n")
   520  
   521  	buf.Write(g.buf.Bytes())
   522  
   523  	src, err := format.Source(buf.Bytes())
   524  	if err != nil {
   525  		log.Printf("=== error ===\n%s\n", buf.Bytes())
   526  	}
   527  	return src, err
   528  }
   529  
   530  func importPkg(p string) (*types.Package, error) {
   531  	cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps}
   532  	pkgs, err := packages.Load(cfg, p)
   533  	if err != nil {
   534  		return nil, fmt.Errorf("could not load package %q: %w", p, err)
   535  	}
   536  
   537  	return pkgs[0].Types, nil
   538  }
   539  
   540  func init() {
   541  	pkg, err := importPkg("encoding")
   542  	if err != nil {
   543  		log.Fatalf("error finding package \"encoding\": %v\n", err)
   544  	}
   545  
   546  	o := pkg.Scope().Lookup("BinaryMarshaler")
   547  	if o == nil {
   548  		log.Fatalf("could not find interface encoding.BinaryMarshaler\n")
   549  	}
   550  	binMa = o.(*types.TypeName).Type().Underlying().(*types.Interface)
   551  
   552  	o = pkg.Scope().Lookup("BinaryUnmarshaler")
   553  	if o == nil {
   554  		log.Fatalf("could not find interface encoding.BinaryUnmarshaler\n")
   555  	}
   556  	binUn = o.(*types.TypeName).Type().Underlying().(*types.Interface)
   557  }