go-hep.org/x/hep@v0.38.1/xrootd/xrdproto/gen-marshal.go (about)

     1  // Copyright ©2018 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  
     7  package main
     8  
     9  import (
    10  	"bytes"
    11  	"flag"
    12  	"fmt"
    13  	"go/format"
    14  	"go/types"
    15  	"io"
    16  	"log"
    17  	"os"
    18  	"strings"
    19  
    20  	"golang.org/x/tools/go/packages"
    21  )
    22  
    23  func main() {
    24  	var (
    25  		typeNames = flag.String("t", "", "comma-separated list of type names")
    26  		pkgPath   = flag.String("p", "", "package import path")
    27  	)
    28  
    29  	flag.Parse()
    30  
    31  	log.SetPrefix("gen-xrd: ")
    32  	log.SetFlags(0)
    33  
    34  	if *typeNames == "" {
    35  		flag.Usage()
    36  		os.Exit(2)
    37  	}
    38  
    39  	types := strings.Split(*typeNames, ",")
    40  	g, err := NewGenerator(*pkgPath)
    41  	if err != nil {
    42  		log.Fatal(err)
    43  	}
    44  
    45  	for _, t := range types {
    46  		g.Generate(t)
    47  	}
    48  
    49  	buf, err := g.Format()
    50  	if err != nil {
    51  		log.Fatalf("gofmt: %v\n", err)
    52  	}
    53  
    54  	_, err = io.Copy(os.Stdout, bytes.NewReader(buf))
    55  	if err != nil {
    56  		log.Fatalf("error generating (un)marshaler code: %v\n", err)
    57  	}
    58  }
    59  
    60  // Generator holds the state of the generation.
    61  type Generator struct {
    62  	buf *bytes.Buffer
    63  	pkg *types.Package
    64  
    65  	Verbose bool // enable verbose mode
    66  }
    67  
    68  // NewGenerator returns a new code generator for package p,
    69  // where p is the package's import path.
    70  func NewGenerator(p string) (*Generator, error) {
    71  	pkg, err := importPkg(p)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return &Generator{
    77  		buf: new(bytes.Buffer),
    78  		pkg: pkg,
    79  	}, nil
    80  }
    81  
    82  func (g *Generator) printf(format string, args ...any) {
    83  	fmt.Fprintf(g.buf, format, args...)
    84  }
    85  
    86  func (g *Generator) Generate(typeName string) {
    87  	scope := g.pkg.Scope()
    88  	obj := scope.Lookup(typeName)
    89  	if obj == nil {
    90  		log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name())
    91  	}
    92  
    93  	tn, ok := obj.(*types.TypeName)
    94  	if !ok {
    95  		log.Fatalf("%q is not a type (%v)\n", typeName, obj)
    96  	}
    97  
    98  	typ, ok := tn.Type().Underlying().(*types.Struct)
    99  	if !ok {
   100  		log.Fatalf("%q is not a named struct (%v)\n", typeName, tn)
   101  	}
   102  	if g.Verbose {
   103  		log.Printf("typ: %+v\n", typ)
   104  	}
   105  
   106  	g.genMarshalXrd(typ, typeName)
   107  	g.genUnmarshalXrd(typ, typeName)
   108  }
   109  
   110  func (g *Generator) genMarshalXrd(t types.Type, typeName string) {
   111  	g.printf(`// MarshalXrd implements xrdproto.Marshaler
   112  func (o %[1]s) MarshalXrd(wBuffer *xrdenc.WBuffer) error {
   113  `,
   114  		typeName,
   115  	)
   116  
   117  	typ := t.Underlying().(*types.Struct)
   118  	for i := 0; i < typ.NumFields(); i++ {
   119  		ft := typ.Field(i)
   120  		g.genMarshalType(ft.Type(), "o."+ft.Name())
   121  	}
   122  
   123  	g.printf("return nil\n}\n\n")
   124  }
   125  
   126  func (g *Generator) genMarshalType(t types.Type, n string) {
   127  	ut := t.Underlying()
   128  	switch ut := ut.(type) {
   129  	case *types.Basic:
   130  		switch kind := ut.Kind(); kind {
   131  
   132  		case types.Bool:
   133  			g.printf("wBuffer.WriteBool(%s)\n", g.upcasted(t, n))
   134  
   135  		case types.Uint8:
   136  			if n == "o._" {
   137  				g.printf("wBuffer.Next(1)\n")
   138  			} else {
   139  				g.printf("wBuffer.WriteU8(%s)\n", g.upcasted(t, n))
   140  			}
   141  
   142  		case types.Uint16:
   143  			g.printf("wBuffer.WriteU16(%s)\n", g.upcasted(t, n))
   144  
   145  		case types.Uint32:
   146  			g.printf("wBuffer.WriteI32(int32(%s))\n", g.upcasted(t, n))
   147  
   148  		case types.Uint64:
   149  			g.printf("wBuffer.WriteI64(int64(%s))\n", g.upcasted(t, n))
   150  
   151  		case types.Int8:
   152  			g.printf("wBuffer.WriteU8(uint8(%s))\n", g.upcasted(t, n))
   153  
   154  		case types.Int16:
   155  			g.printf("wBuffer.WriteU16(uint16(%s))\n", g.upcasted(t, n))
   156  
   157  		case types.Int32:
   158  			if n == "o._" {
   159  				g.printf("wBuffer.Next(4)\n")
   160  			} else {
   161  				g.printf("wBuffer.WriteI32(%s)\n", g.upcasted(t, n))
   162  			}
   163  
   164  		case types.Int64:
   165  			g.printf("wBuffer.WriteI64(%s)\n", g.upcasted(t, n))
   166  
   167  		case types.String:
   168  			g.printf("wBuffer.WriteStr(%s)\n", g.upcasted(t, n))
   169  
   170  		default:
   171  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   172  		}
   173  
   174  	case *types.Struct:
   175  		g.printf("if err := %s.MarshalXrd(wBuffer); err != nil {\nreturn err\n}\n", n)
   176  
   177  	case *types.Array:
   178  		if !isByteType(ut.Elem()) {
   179  			log.Fatalf("marshal array of type %v not supported", ut)
   180  		}
   181  		if n == "o._" {
   182  			g.printf("wBuffer.Next(%d)\n", ut.Len())
   183  		} else {
   184  			g.printf("wBuffer.WriteBytes(%s[:])\n", n)
   185  		}
   186  
   187  	case *types.Slice:
   188  		if !isByteType(ut.Elem()) {
   189  			g.printf("wBuffer.WriteLen(len(%s))\n", n)
   190  			g.printf(`for _, x := range %s {
   191  	err := x.MarshalXrd(wBuffer)
   192  	if err != nil {
   193  		return err
   194  	}
   195  }
   196  `, n)
   197  		} else {
   198  			g.printf("wBuffer.WriteLen(len(%s))\n", n)
   199  			g.printf("wBuffer.WriteBytes(%s)\n", n)
   200  		}
   201  
   202  	default:
   203  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   204  	}
   205  }
   206  
   207  func (g *Generator) genUnmarshalXrd(t types.Type, typeName string) {
   208  	g.printf(`// UnmarshalXrd implements xrdproto.Unmarshaler
   209  func (o *%[1]s) UnmarshalXrd(rBuffer *xrdenc.RBuffer) error {
   210  `,
   211  		typeName,
   212  	)
   213  
   214  	typ := t.Underlying().(*types.Struct)
   215  	for i := 0; i < typ.NumFields(); i++ {
   216  		ft := typ.Field(i)
   217  		g.genUnmarshalType(ft.Type(), "o."+ft.Name())
   218  	}
   219  
   220  	g.printf("return nil\n}\n\n")
   221  }
   222  
   223  func (g *Generator) downcasted(t types.Type, expression string) string {
   224  	if named, ok := t.(*types.Named); ok {
   225  		cast := qualTypeName(named, g.pkg)
   226  		return cast + "(" + expression + ")"
   227  	}
   228  	return expression
   229  }
   230  
   231  func (g *Generator) upcasted(t types.Type, expression string) string {
   232  	if named, ok := t.(*types.Named); ok {
   233  		ut := named.Underlying()
   234  		if basic, ok := ut.(*types.Basic); ok {
   235  			cast := basic.Name()
   236  			return cast + "(" + expression + ")"
   237  		}
   238  	}
   239  	return expression
   240  }
   241  
   242  func (g *Generator) genUnmarshalType(t types.Type, n string) {
   243  	ut := t.Underlying()
   244  	switch ut := ut.(type) {
   245  	case *types.Basic:
   246  		switch kind := ut.Kind(); kind {
   247  		case types.Bool:
   248  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadBool()"))
   249  
   250  		case types.Uint:
   251  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint(rBuffer.ReadI64())"))
   252  
   253  		case types.Uint8:
   254  			if n == "o._" {
   255  				g.printf("rBuffer.Skip(1)\n")
   256  			} else {
   257  				g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadU8()"))
   258  			}
   259  
   260  		case types.Uint16:
   261  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadU16()"))
   262  
   263  		case types.Uint32:
   264  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint32(rBuffer.ReadI32())"))
   265  
   266  		case types.Uint64:
   267  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint64(rBuffer.ReadI64())"))
   268  
   269  		case types.Int8:
   270  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "int8(rBuffer.ReadU8())"))
   271  		case types.Int16:
   272  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "int16(rBuffer.ReadU16())"))
   273  
   274  		case types.Int32:
   275  			if n == "o._" {
   276  				g.printf("rBuffer.Skip(4)\n")
   277  			} else {
   278  				g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadI32()"))
   279  			}
   280  
   281  		case types.Int64:
   282  			g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadI64()"))
   283  
   284  		case types.String:
   285  			g.printf("%s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadStr()"))
   286  
   287  		default:
   288  			log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   289  		}
   290  
   291  	case *types.Struct:
   292  		g.printf("if err := %s.UnmarshalXrd(rBuffer); err != nil {\n return err\n}\n", n)
   293  
   294  	case *types.Array:
   295  		if !isByteType(ut.Elem()) {
   296  			log.Fatalf("unmarshal array of type %v not supported", ut)
   297  		}
   298  		if n == "o._" {
   299  			g.printf("rBuffer.Skip(%d)\n", ut.Len())
   300  		} else {
   301  			g.printf("rBuffer.ReadBytes(%s[:])\n", n)
   302  		}
   303  
   304  	case *types.Slice:
   305  		if !isByteType(ut.Elem()) {
   306  			g.printf("%[1]s = make([]%[2]s, rBuffer.ReadLen())\n", n, qualTypeName(ut.Elem(), g.pkg))
   307  			g.printf(`for i:=0; i<len(%[1]s); i++ {
   308      err := %[1]s[i].UnmarshalXrd(rBuffer)
   309      if err != nil {
   310  		return err
   311      }
   312  }
   313  `, n)
   314  		} else {
   315  			g.printf("%[1]s = make([]%[2]s, rBuffer.ReadLen())\n", n, qualTypeName(ut.Elem(), g.pkg))
   316  			g.printf("rBuffer.ReadBytes(%s)\n", n)
   317  		}
   318  
   319  	default:
   320  		log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut)
   321  	}
   322  
   323  }
   324  
   325  func isByteType(t types.Type) bool {
   326  	b, ok := t.Underlying().(*types.Basic)
   327  	if !ok {
   328  		return false
   329  	}
   330  	return b.Kind() == types.Byte
   331  }
   332  
   333  func qualTypeName(t types.Type, pkg *types.Package) string {
   334  	n := types.TypeString(t, types.RelativeTo(pkg))
   335  	i := strings.LastIndex(n, "/")
   336  	if i < 0 {
   337  		return n
   338  	}
   339  	return string(n[i+1:])
   340  }
   341  
   342  func (g *Generator) Format() ([]byte, error) {
   343  	buf := new(bytes.Buffer)
   344  
   345  	buf.Write(g.buf.Bytes())
   346  
   347  	src, err := format.Source(buf.Bytes())
   348  	if err != nil {
   349  		log.Printf("=== error ===\n%s\n", buf.Bytes())
   350  	}
   351  	return src, err
   352  }
   353  
   354  func importPkg(p string) (*types.Package, error) {
   355  	cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps}
   356  	pkgs, err := packages.Load(cfg, p)
   357  	if err != nil {
   358  		return nil, fmt.Errorf("could not load package %q: %w", p, err)
   359  	}
   360  
   361  	return pkgs[0].Types, nil
   362  }