go-hep.org/x/hep@v0.38.1/xrootd/xrdproto/gen-marshal.go (about) 1 // Copyright ©2018 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build ignore 6 7 package main 8 9 import ( 10 "bytes" 11 "flag" 12 "fmt" 13 "go/format" 14 "go/types" 15 "io" 16 "log" 17 "os" 18 "strings" 19 20 "golang.org/x/tools/go/packages" 21 ) 22 23 func main() { 24 var ( 25 typeNames = flag.String("t", "", "comma-separated list of type names") 26 pkgPath = flag.String("p", "", "package import path") 27 ) 28 29 flag.Parse() 30 31 log.SetPrefix("gen-xrd: ") 32 log.SetFlags(0) 33 34 if *typeNames == "" { 35 flag.Usage() 36 os.Exit(2) 37 } 38 39 types := strings.Split(*typeNames, ",") 40 g, err := NewGenerator(*pkgPath) 41 if err != nil { 42 log.Fatal(err) 43 } 44 45 for _, t := range types { 46 g.Generate(t) 47 } 48 49 buf, err := g.Format() 50 if err != nil { 51 log.Fatalf("gofmt: %v\n", err) 52 } 53 54 _, err = io.Copy(os.Stdout, bytes.NewReader(buf)) 55 if err != nil { 56 log.Fatalf("error generating (un)marshaler code: %v\n", err) 57 } 58 } 59 60 // Generator holds the state of the generation. 61 type Generator struct { 62 buf *bytes.Buffer 63 pkg *types.Package 64 65 Verbose bool // enable verbose mode 66 } 67 68 // NewGenerator returns a new code generator for package p, 69 // where p is the package's import path. 70 func NewGenerator(p string) (*Generator, error) { 71 pkg, err := importPkg(p) 72 if err != nil { 73 return nil, err 74 } 75 76 return &Generator{ 77 buf: new(bytes.Buffer), 78 pkg: pkg, 79 }, nil 80 } 81 82 func (g *Generator) printf(format string, args ...any) { 83 fmt.Fprintf(g.buf, format, args...) 84 } 85 86 func (g *Generator) Generate(typeName string) { 87 scope := g.pkg.Scope() 88 obj := scope.Lookup(typeName) 89 if obj == nil { 90 log.Fatalf("no such type %q in package %q\n", typeName, g.pkg.Path()+"/"+g.pkg.Name()) 91 } 92 93 tn, ok := obj.(*types.TypeName) 94 if !ok { 95 log.Fatalf("%q is not a type (%v)\n", typeName, obj) 96 } 97 98 typ, ok := tn.Type().Underlying().(*types.Struct) 99 if !ok { 100 log.Fatalf("%q is not a named struct (%v)\n", typeName, tn) 101 } 102 if g.Verbose { 103 log.Printf("typ: %+v\n", typ) 104 } 105 106 g.genMarshalXrd(typ, typeName) 107 g.genUnmarshalXrd(typ, typeName) 108 } 109 110 func (g *Generator) genMarshalXrd(t types.Type, typeName string) { 111 g.printf(`// MarshalXrd implements xrdproto.Marshaler 112 func (o %[1]s) MarshalXrd(wBuffer *xrdenc.WBuffer) error { 113 `, 114 typeName, 115 ) 116 117 typ := t.Underlying().(*types.Struct) 118 for i := 0; i < typ.NumFields(); i++ { 119 ft := typ.Field(i) 120 g.genMarshalType(ft.Type(), "o."+ft.Name()) 121 } 122 123 g.printf("return nil\n}\n\n") 124 } 125 126 func (g *Generator) genMarshalType(t types.Type, n string) { 127 ut := t.Underlying() 128 switch ut := ut.(type) { 129 case *types.Basic: 130 switch kind := ut.Kind(); kind { 131 132 case types.Bool: 133 g.printf("wBuffer.WriteBool(%s)\n", g.upcasted(t, n)) 134 135 case types.Uint8: 136 if n == "o._" { 137 g.printf("wBuffer.Next(1)\n") 138 } else { 139 g.printf("wBuffer.WriteU8(%s)\n", g.upcasted(t, n)) 140 } 141 142 case types.Uint16: 143 g.printf("wBuffer.WriteU16(%s)\n", g.upcasted(t, n)) 144 145 case types.Uint32: 146 g.printf("wBuffer.WriteI32(int32(%s))\n", g.upcasted(t, n)) 147 148 case types.Uint64: 149 g.printf("wBuffer.WriteI64(int64(%s))\n", g.upcasted(t, n)) 150 151 case types.Int8: 152 g.printf("wBuffer.WriteU8(uint8(%s))\n", g.upcasted(t, n)) 153 154 case types.Int16: 155 g.printf("wBuffer.WriteU16(uint16(%s))\n", g.upcasted(t, n)) 156 157 case types.Int32: 158 if n == "o._" { 159 g.printf("wBuffer.Next(4)\n") 160 } else { 161 g.printf("wBuffer.WriteI32(%s)\n", g.upcasted(t, n)) 162 } 163 164 case types.Int64: 165 g.printf("wBuffer.WriteI64(%s)\n", g.upcasted(t, n)) 166 167 case types.String: 168 g.printf("wBuffer.WriteStr(%s)\n", g.upcasted(t, n)) 169 170 default: 171 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 172 } 173 174 case *types.Struct: 175 g.printf("if err := %s.MarshalXrd(wBuffer); err != nil {\nreturn err\n}\n", n) 176 177 case *types.Array: 178 if !isByteType(ut.Elem()) { 179 log.Fatalf("marshal array of type %v not supported", ut) 180 } 181 if n == "o._" { 182 g.printf("wBuffer.Next(%d)\n", ut.Len()) 183 } else { 184 g.printf("wBuffer.WriteBytes(%s[:])\n", n) 185 } 186 187 case *types.Slice: 188 if !isByteType(ut.Elem()) { 189 g.printf("wBuffer.WriteLen(len(%s))\n", n) 190 g.printf(`for _, x := range %s { 191 err := x.MarshalXrd(wBuffer) 192 if err != nil { 193 return err 194 } 195 } 196 `, n) 197 } else { 198 g.printf("wBuffer.WriteLen(len(%s))\n", n) 199 g.printf("wBuffer.WriteBytes(%s)\n", n) 200 } 201 202 default: 203 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 204 } 205 } 206 207 func (g *Generator) genUnmarshalXrd(t types.Type, typeName string) { 208 g.printf(`// UnmarshalXrd implements xrdproto.Unmarshaler 209 func (o *%[1]s) UnmarshalXrd(rBuffer *xrdenc.RBuffer) error { 210 `, 211 typeName, 212 ) 213 214 typ := t.Underlying().(*types.Struct) 215 for i := 0; i < typ.NumFields(); i++ { 216 ft := typ.Field(i) 217 g.genUnmarshalType(ft.Type(), "o."+ft.Name()) 218 } 219 220 g.printf("return nil\n}\n\n") 221 } 222 223 func (g *Generator) downcasted(t types.Type, expression string) string { 224 if named, ok := t.(*types.Named); ok { 225 cast := qualTypeName(named, g.pkg) 226 return cast + "(" + expression + ")" 227 } 228 return expression 229 } 230 231 func (g *Generator) upcasted(t types.Type, expression string) string { 232 if named, ok := t.(*types.Named); ok { 233 ut := named.Underlying() 234 if basic, ok := ut.(*types.Basic); ok { 235 cast := basic.Name() 236 return cast + "(" + expression + ")" 237 } 238 } 239 return expression 240 } 241 242 func (g *Generator) genUnmarshalType(t types.Type, n string) { 243 ut := t.Underlying() 244 switch ut := ut.(type) { 245 case *types.Basic: 246 switch kind := ut.Kind(); kind { 247 case types.Bool: 248 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadBool()")) 249 250 case types.Uint: 251 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint(rBuffer.ReadI64())")) 252 253 case types.Uint8: 254 if n == "o._" { 255 g.printf("rBuffer.Skip(1)\n") 256 } else { 257 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadU8()")) 258 } 259 260 case types.Uint16: 261 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadU16()")) 262 263 case types.Uint32: 264 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint32(rBuffer.ReadI32())")) 265 266 case types.Uint64: 267 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "uint64(rBuffer.ReadI64())")) 268 269 case types.Int8: 270 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "int8(rBuffer.ReadU8())")) 271 case types.Int16: 272 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "int16(rBuffer.ReadU16())")) 273 274 case types.Int32: 275 if n == "o._" { 276 g.printf("rBuffer.Skip(4)\n") 277 } else { 278 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadI32()")) 279 } 280 281 case types.Int64: 282 g.printf("%[1]s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadI64()")) 283 284 case types.String: 285 g.printf("%s = %[2]s\n", n, g.downcasted(t, "rBuffer.ReadStr()")) 286 287 default: 288 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 289 } 290 291 case *types.Struct: 292 g.printf("if err := %s.UnmarshalXrd(rBuffer); err != nil {\n return err\n}\n", n) 293 294 case *types.Array: 295 if !isByteType(ut.Elem()) { 296 log.Fatalf("unmarshal array of type %v not supported", ut) 297 } 298 if n == "o._" { 299 g.printf("rBuffer.Skip(%d)\n", ut.Len()) 300 } else { 301 g.printf("rBuffer.ReadBytes(%s[:])\n", n) 302 } 303 304 case *types.Slice: 305 if !isByteType(ut.Elem()) { 306 g.printf("%[1]s = make([]%[2]s, rBuffer.ReadLen())\n", n, qualTypeName(ut.Elem(), g.pkg)) 307 g.printf(`for i:=0; i<len(%[1]s); i++ { 308 err := %[1]s[i].UnmarshalXrd(rBuffer) 309 if err != nil { 310 return err 311 } 312 } 313 `, n) 314 } else { 315 g.printf("%[1]s = make([]%[2]s, rBuffer.ReadLen())\n", n, qualTypeName(ut.Elem(), g.pkg)) 316 g.printf("rBuffer.ReadBytes(%s)\n", n) 317 } 318 319 default: 320 log.Fatalf("unhandled type: %v (underlying: %v)\n", t, ut) 321 } 322 323 } 324 325 func isByteType(t types.Type) bool { 326 b, ok := t.Underlying().(*types.Basic) 327 if !ok { 328 return false 329 } 330 return b.Kind() == types.Byte 331 } 332 333 func qualTypeName(t types.Type, pkg *types.Package) string { 334 n := types.TypeString(t, types.RelativeTo(pkg)) 335 i := strings.LastIndex(n, "/") 336 if i < 0 { 337 return n 338 } 339 return string(n[i+1:]) 340 } 341 342 func (g *Generator) Format() ([]byte, error) { 343 buf := new(bytes.Buffer) 344 345 buf.Write(g.buf.Bytes()) 346 347 src, err := format.Source(buf.Bytes()) 348 if err != nil { 349 log.Printf("=== error ===\n%s\n", buf.Bytes()) 350 } 351 return src, err 352 } 353 354 func importPkg(p string) (*types.Package, error) { 355 cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedTypesSizes | packages.NeedDeps} 356 pkgs, err := packages.Load(cfg, p) 357 if err != nil { 358 return nil, fmt.Errorf("could not load package %q: %w", p, err) 359 } 360 361 return pkgs[0].Types, nil 362 }