gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_marshal/gomarshal/util.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gomarshal 16 17 import ( 18 "bytes" 19 "flag" 20 "fmt" 21 "go/ast" 22 "go/token" 23 "io" 24 "os" 25 "path" 26 "reflect" 27 "sort" 28 "strings" 29 ) 30 31 var debug = flag.Bool("debug", false, "enables debugging output") 32 33 // receiverName returns an appropriate receiver name given a type spec. 34 func receiverName(t *ast.TypeSpec) string { 35 if len(t.Name.Name) < 1 { 36 // Zero length type name? 37 panic("unreachable") 38 } 39 return strings.ToLower(t.Name.Name[:1]) 40 } 41 42 // kindString returns a user-friendly representation of an AST expr type. 
43 func kindString(e ast.Expr) string { 44 switch e.(type) { 45 case *ast.Ident: 46 return "scalar" 47 case *ast.ArrayType: 48 return "array" 49 case *ast.StructType: 50 return "struct" 51 case *ast.StarExpr: 52 return "pointer" 53 case *ast.FuncType: 54 return "function" 55 case *ast.InterfaceType: 56 return "interface" 57 case *ast.MapType: 58 return "map" 59 case *ast.ChanType: 60 return "channel" 61 default: 62 return reflect.TypeOf(e).String() 63 } 64 } 65 66 func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { 67 for _, field := range st.Fields.List { 68 fn(field) 69 } 70 } 71 72 // fieldDispatcher is a collection of callbacks for handling different types of 73 // fields in a struct declaration. 74 type fieldDispatcher struct { 75 primitive func(n, t *ast.Ident) 76 selector func(n, tX, tSel *ast.Ident) 77 array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) 78 unhandled func(n *ast.Ident) 79 } 80 81 // Precondition: All dispatch callbacks that will be invoked must be 82 // provided. 83 func (fd fieldDispatcher) dispatch(f *ast.Field) { 84 // Each field declaration may actually be multiple declarations of the same 85 // type. For example, consider: 86 // 87 // type Point struct { 88 // x, y, z int 89 // } 90 // 91 // We invoke the call-backs once per such instance. 92 93 // Handle embedded fields. Embedded fields have no names, but can be 94 // referenced by the type name. 95 if len(f.Names) < 1 { 96 switch v := f.Type.(type) { 97 case *ast.Ident: 98 fd.primitive(v, v) 99 case *ast.SelectorExpr: 100 fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel) 101 default: 102 // Note: Arrays can't be embedded, which is handled here. 103 panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type)) 104 } 105 return 106 } 107 108 // Non-embedded field. 
109 for _, name := range f.Names { 110 switch v := f.Type.(type) { 111 case *ast.Ident: 112 fd.primitive(name, v) 113 case *ast.SelectorExpr: 114 fd.selector(name, v.X.(*ast.Ident), v.Sel) 115 case *ast.ArrayType: 116 switch t := v.Elt.(type) { 117 case *ast.Ident: 118 fd.array(name, v, t) 119 default: 120 // Should be handled with a better error message during validate. 121 panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) 122 } 123 default: 124 fd.unhandled(name) 125 } 126 } 127 } 128 129 // debugEnabled indicates whether debugging is enabled for gomarshal. 130 func debugEnabled() bool { 131 return *debug 132 } 133 134 // abort aborts the go_marshal tool with the given error message. 135 func abort(msg string) { 136 if !strings.HasSuffix(msg, "\n") { 137 msg += "\n" 138 } 139 fmt.Print(msg) 140 os.Exit(1) 141 } 142 143 // abortAt aborts the go_marshal tool with the given error message, with 144 // a reference position to the input source. 145 func abortAt(p token.Position, msg string) { 146 abort(fmt.Sprintf("%v:\n %s\n", p, msg)) 147 } 148 149 // debugf conditionally prints a debug message. 150 func debugf(f string, a ...any) { 151 if debugEnabled() { 152 fmt.Printf(f, a...) 153 } 154 } 155 156 // debugfAt conditionally prints a debug message with a reference to a position 157 // in the input source. 158 func debugfAt(p token.Position, f string, a ...any) { 159 if debugEnabled() { 160 fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) 161 } 162 } 163 164 // emit generates a line of code in the output file. 165 // 166 // emit is a wrapper around writing a formatted string to the output 167 // buffer. emit can be invoked in one of two ways: 168 // 169 // (1) emit("some string") 170 // 171 // When emit is called with a single string argument, it is simply copied to 172 // the output buffer without any further formatting. 173 // 174 // (2) emit(fmtString, args...) 
//
// emit can also be invoked in a similar fashion to *Printf() functions,
// where the first argument is a format string.
//
// In both cases the output is preceded by indent levels of four spaces each
// (none when indent <= 0). No newline is appended; callers include it in the
// string.
//
// Calling emit with no arguments, or with a first argument that is not a
// string, results in a panic, as the caller's intent is ambiguous. Arguments
// are validated before anything is written, so such a call leaves out
// untouched.
func emit(out io.Writer, indent int, a ...any) {
	const spacesPerIndentLevel = 4

	if len(a) < 1 {
		panic("emit() called with no arguments")
	}
	first, ok := a[0].(string)
	if !ok {
		// First argument must be either the string to emit (case 1 from
		// function-level comment), or a format string (case 2).
		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
	}

	if indent > 0 {
		if _, err := io.WriteString(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
			// Writing to the emit output should not fail. Typically the output
			// is a bytes.Buffer; writes to these never fail.
			panic(err)
		}
	}

	var err error
	if len(a) == 1 {
		// Single string argument. Assume no formatting requested.
		_, err = io.WriteString(out, first)
	} else {
		// Formatting requested.
		_, err = fmt.Fprintf(out, first, a[1:]...)
	}
	if err != nil {
		// Writing to out should not fail.
		panic(err)
	}
}

// sourceBuffer represents fragments of generated go source code.
//
// sourceBuffer provides a convenient way to build up go source fragments in
// memory. May be safely zero-value initialized. Not thread-safe.
type sourceBuffer struct {
	// Current indentation level.
	indent int

	// Memory buffer containing contents while they're being generated.
	b bytes.Buffer
}

// reset discards all buffered contents and sets the indentation level back to
// zero.
func (b *sourceBuffer) reset() {
	b.indent = 0
	b.b.Reset()
}

// incIndent increases the indentation level used by emit by one.
func (b *sourceBuffer) incIndent() {
	b.indent++
}

// decIndent decreases the indentation level by one. It panics if the level is
// already zero, which indicates an unbalanced incIndent/decIndent pair.
func (b *sourceBuffer) decIndent() {
	if b.indent <= 0 {
		panic("decIndent() without matching incIndent()")
	}
	b.indent--
}

// emit appends to the buffer at the current indentation level. The arguments
// are interpreted as by the package-level emit function.
func (b *sourceBuffer) emit(a ...any) {
	emit(&b.b, b.indent, a...)
}

// emitNoIndent is like emit, but ignores the current indentation level.
func (b *sourceBuffer) emitNoIndent(a ...any) {
	emit(&b.b, 0 /*indent*/, a...)
}

// inIndent runs body with the indentation level increased by one, and
// restores the previous level afterwards.
func (b *sourceBuffer) inIndent(body func()) {
	b.incIndent()
	defer b.decIndent()
	body()
}

// write copies the buffered contents to out. The buffer itself is left
// unchanged.
func (b *sourceBuffer) write(out io.Writer) error {
	_, err := out.Write(b.b.Bytes())
	return err
}

// Write implements io.Writer.Write by appending buf verbatim, without
// indentation.
func (b *sourceBuffer) Write(buf []byte) (int, error) {
	return b.b.Write(buf)
}

// importStmt represents a single import statement.
type importStmt struct {
	// Local name of the imported package.
	name string
	// Import path.
	path string
	// Indicates whether the local name is an alias, or simply the final
	// component of the path.
	aliased bool
	// Indicates whether this import was referenced by generated code.
	used bool
	// AST node and file set representing the import statement, if any. These
	// are only non-nil if the import statement originates from an input source
	// file.
	spec *ast.ImportSpec
	fset *token.FileSet
}

// newImport returns an unaliased import statement for path p, whose local name
// is taken to be the final component of p.
func newImport(p string) *importStmt {
	name := path.Base(p)
	return &importStmt{
		name:    name,
		path:    p,
		aliased: false,
	}
}

// newImportFromSpec builds an import statement from an import spec parsed
// from an input source file. f is the file set used to report the spec's
// position. If the spec has an explicit name, that name becomes the
// statement's aliased local name. Panics if no local package name can be
// derived from the import path.
func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
	name := path.Base(p)
	if name == "" || name == "/" || name == "." {
		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
			f.Position(spec.Path.Pos()), name))
	}
	if spec.Name != nil {
		name = spec.Name.Name
	}
	return &importStmt{
		name:    name,
		path:    p,
		aliased: spec.Name != nil,
		spec:    spec,
		fset:    f,
	}
}

// String implements fmt.Stringer.String. This generates a string for the import
// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
	if i.aliased {
		return fmt.Sprintf("%s %q", i.name, i.path)
	}
	return fmt.Sprintf("%q", i.path)
}

// debugString returns a debug string representing an import statement. This
// representation is not valid golang code and is used for debugging output.
func (i *importStmt) debugString() string {
	if i.spec != nil && i.fset != nil {
		return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
	}
	return fmt.Sprintf("(go-marshal import): %s", i)
}

// markUsed records that generated code references this import.
func (i *importStmt) markUsed() {
	i.used = true
}

// equivalent reports whether i and other would render as the same import
// statement: same local name, same path, and same aliasing.
func (i *importStmt) equivalent(other *importStmt) bool {
	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}

// importTable represents a collection of importStmts.
//
// An importTable may contain multiple import statements that bind the same
// local name. Such statements are ambiguous: if that name is used in the
// generated code, it's not clear which import statement it refers to. We
// ignore any potential collisions until actually writing the import table to
// the generated source file. See importTable.write.
352 // 353 // Given the following import statements across all the files comprising a 354 // package marshalled: 355 // 356 // "sync" 357 // "pkg/sync" 358 // "pkg/sentry/kernel" 359 // ktime "pkg/sentry/kernel/time" 360 // 361 // An importTable representing them would look like this: 362 // 363 // importTable { 364 // is: map[string][]*importStmt { 365 // "sync": []*importStmt{ 366 // importStmt{name:"sync", path:"sync", aliased:false} 367 // importStmt{name:"sync", path:"pkg/sync", aliased:false} 368 // }, 369 // "kernel": []*importStmt{importStmt{ 370 // name: "kernel", 371 // path: "pkg/sentry/kernel", 372 // aliased: false 373 // }}, 374 // "ktime": []*importStmt{importStmt{ 375 // name: "ktime", 376 // path: "pkg/sentry/kernel/time", 377 // aliased: true, 378 // }}, 379 // } 380 // } 381 // 382 // Note that the local name "sync" is assigned to two different import 383 // statements. This is possible if the import statements are from different 384 // source files in the same package. 385 // 386 // Since go-marshal generates a single output file per package regardless of the 387 // number of input files, if "sync" is referenced by any generated code, it's 388 // unclear which import statement "sync" refers to. While it's theoretically 389 // possible to resolve this by assigning a unique local alias to each instance 390 // of the sync package, go-marshal currently aborts when it encounters such an 391 // ambiguity. 392 // 393 // TODO(b/151478251): importTable considers the final component of an import 394 // path to be the package name, but this is only a convention. The actual 395 // package name is determined by the package statement in the source files for 396 // the package. 397 type importTable struct { 398 // Map of imports and whether they should be copied to the output. 
399 is map[string][]*importStmt 400 } 401 402 func newImportTable() *importTable { 403 return &importTable{ 404 is: make(map[string][]*importStmt), 405 } 406 } 407 408 // Merges import statements from other into i. 409 func (i *importTable) merge(other *importTable) { 410 for name, ims := range other.is { 411 i.is[name] = append(i.is[name], ims...) 412 } 413 } 414 415 func (i *importTable) addStmt(s *importStmt) *importStmt { 416 i.is[s.name] = append(i.is[s.name], s) 417 return s 418 } 419 420 func (i *importTable) add(s string) *importStmt { 421 n := newImport(s) 422 return i.addStmt(n) 423 } 424 425 func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { 426 return i.addStmt(newImportFromSpec(spec, f)) 427 } 428 429 // Marks the import named n as used. If no such import is in the table, returns 430 // false. 431 func (i *importTable) markUsed(n string) bool { 432 if ns, ok := i.is[n]; ok { 433 for _, n := range ns { 434 n.markUsed() 435 } 436 return true 437 } 438 return false 439 } 440 441 func (i *importTable) clear() { 442 for _, is := range i.is { 443 for _, i := range is { 444 i.used = false 445 } 446 } 447 } 448 449 func (i *importTable) write(out io.Writer) error { 450 if len(i.is) == 0 { 451 // Nothing to import, we're done. 452 return nil 453 } 454 455 imports := make([]string, 0, len(i.is)) 456 for name, is := range i.is { 457 var lastUsed *importStmt 458 var ambiguous bool 459 460 for _, i := range is { 461 if i.used { 462 if lastUsed != nil { 463 if !i.equivalent(lastUsed) { 464 ambiguous = true 465 } 466 } 467 lastUsed = i 468 } 469 } 470 471 if ambiguous { 472 // We have two or more import statements across the different source 473 // files that share a local name, and at least one of these imports 474 // are used by the generated code. This ambiguity can't be resolved 475 // by go-marshal and requires the user intervention. 
Dump a list of 476 // the colliding import statements and let the user modify the input 477 // files as appropriate. 478 var b strings.Builder 479 fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) 480 fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) 481 // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't 482 // be true. Therefore the slicing below is safe. 483 for _, i := range is[:len(is)-1] { 484 fmt.Fprintf(&b, " %v\n", i.debugString()) 485 } 486 fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) 487 panic(b.String()) 488 } 489 490 if lastUsed != nil { 491 imports = append(imports, lastUsed.String()) 492 } 493 } 494 sort.Strings(imports) 495 496 var b sourceBuffer 497 b.emit("import (\n") 498 b.incIndent() 499 for _, i := range imports { 500 b.emit("%s\n", i) 501 } 502 b.decIndent() 503 b.emit(")\n\n") 504 505 return b.write(out) 506 }