github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/util.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gomarshal 16 17 import ( 18 "bytes" 19 "flag" 20 "fmt" 21 "go/ast" 22 "go/token" 23 "io" 24 "os" 25 "path" 26 "reflect" 27 "sort" 28 "strings" 29 ) 30 31 var debug = flag.Bool("debug", false, "enables debugging output") 32 33 // receiverName returns an appropriate receiver name given a type spec. 34 func receiverName(t *ast.TypeSpec) string { 35 if len(t.Name.Name) < 1 { 36 // Zero length type name? 37 panic("unreachable") 38 } 39 return strings.ToLower(t.Name.Name[:1]) 40 } 41 42 // kindString returns a user-friendly representation of an AST expr type. 
43 func kindString(e ast.Expr) string { 44 switch e.(type) { 45 case *ast.Ident: 46 return "scalar" 47 case *ast.ArrayType: 48 return "array" 49 case *ast.StructType: 50 return "struct" 51 case *ast.StarExpr: 52 return "pointer" 53 case *ast.FuncType: 54 return "function" 55 case *ast.InterfaceType: 56 return "interface" 57 case *ast.MapType: 58 return "map" 59 case *ast.ChanType: 60 return "channel" 61 default: 62 return reflect.TypeOf(e).String() 63 } 64 } 65 66 func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { 67 for _, field := range st.Fields.List { 68 fn(field) 69 } 70 } 71 72 // fieldDispatcher is a collection of callbacks for handling different types of 73 // fields in a struct declaration. 74 type fieldDispatcher struct { 75 primitive func(n, t *ast.Ident) 76 selector func(n, tX, tSel *ast.Ident) 77 array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) 78 unhandled func(n *ast.Ident) 79 } 80 81 // Precondition: All dispatch callbacks that will be invoked must be 82 // provided. 83 func (fd fieldDispatcher) dispatch(f *ast.Field) { 84 // Each field declaration may actually be multiple declarations of the same 85 // type. For example, consider: 86 // 87 // type Point struct { 88 // x, y, z int 89 // } 90 // 91 // We invoke the call-backs once per such instance. 92 93 // Handle embedded fields. Embedded fields have no names, but can be 94 // referenced by the type name. 95 if len(f.Names) < 1 { 96 switch v := f.Type.(type) { 97 case *ast.Ident: 98 fd.primitive(v, v) 99 case *ast.SelectorExpr: 100 fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel) 101 default: 102 // Note: Arrays can't be embedded, which is handled here. 103 panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type)) 104 } 105 return 106 } 107 108 // Non-embedded field. 
109 for _, name := range f.Names { 110 switch v := f.Type.(type) { 111 case *ast.Ident: 112 fd.primitive(name, v) 113 case *ast.SelectorExpr: 114 fd.selector(name, v.X.(*ast.Ident), v.Sel) 115 case *ast.ArrayType: 116 switch t := v.Elt.(type) { 117 case *ast.Ident: 118 fd.array(name, v, t) 119 default: 120 // Should be handled with a better error message during validate. 121 panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) 122 } 123 default: 124 fd.unhandled(name) 125 } 126 } 127 } 128 129 // debugEnabled indicates whether debugging is enabled for gomarshal. 130 func debugEnabled() bool { 131 return *debug 132 } 133 134 // abort aborts the go_marshal tool with the given error message. 135 func abort(msg string) { 136 if !strings.HasSuffix(msg, "\n") { 137 msg += "\n" 138 } 139 fmt.Print(msg) 140 os.Exit(1) 141 } 142 143 // abortAt aborts the go_marshal tool with the given error message, with 144 // a reference position to the input source. 145 func abortAt(p token.Position, msg string) { 146 abort(fmt.Sprintf("%v:\n %s\n", p, msg)) 147 } 148 149 // debugf conditionally prints a debug message. 150 func debugf(f string, a ...interface{}) { 151 if debugEnabled() { 152 fmt.Printf(f, a...) 153 } 154 } 155 156 // debugfAt conditionally prints a debug message with a reference to a position 157 // in the input source. 158 func debugfAt(p token.Position, f string, a ...interface{}) { 159 if debugEnabled() { 160 fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) 161 } 162 } 163 164 // emit generates a line of code in the output file. 165 // 166 // emit is a wrapper around writing a formatted string to the output 167 // buffer. emit can be invoked in one of two ways: 168 // 169 // (1) emit("some string") 170 // When emit is called with a single string argument, it is simply copied to 171 // the output buffer without any further formatting. 172 // (2) emit(fmtString, args...) 
//      emit can also be invoked in a similar fashion to *Printf() functions,
//      where the first argument is a format string.
//
// Calling emit with no arguments, or with a first argument that is not a
// string, results in a panic, as the caller's intent is ambiguous. When indent
// is positive the text is preceded by indent*4 spaces. A failed write to out
// also panics.
func emit(out io.Writer, indent int, a ...interface{}) {
	const spacesPerIndentLevel = 4

	if len(a) < 1 {
		panic("emit() called with no arguments")
	}

	// Validate the first argument before writing anything, so that misuse does
	// not leave a dangling indent in out.
	first, ok := a[0].(string)
	if !ok {
		// First argument must be either the string to emit (case 1 from
		// function-level comment), or a format string (case 2).
		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
	}

	if indent > 0 {
		if _, err := io.WriteString(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
			// Writing to the emit output should not fail. Typically the output
			// is a bytes.Buffer; writes to these never fail.
			panic(err)
		}
	}

	if len(a) == 1 {
		// Single string argument. Assume no formatting requested, so any '%'
		// in the string is emitted verbatim.
		if _, err := io.WriteString(out, first); err != nil {
			// Writing to out should not fail.
			panic(err)
		}
		return
	}

	// Formatting requested.
	if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
		// Writing to out should not fail.
		panic(err)
	}
}

// sourceBuffer represents fragments of generated go source code.
//
// sourceBuffer provides a convenient way to build up go source fragments in
// memory. May be safely zero-value initialized. Not thread-safe.
type sourceBuffer struct {
	// Current indentation level.
	indent int

	// Memory buffer containing contents while they're being generated.
	b bytes.Buffer
}

// reset discards all buffered contents and sets the indentation level to 0.
func (b *sourceBuffer) reset() {
	b.indent = 0
	b.b.Reset()
}

// incIndent increases the indentation level used by emit by one.
func (b *sourceBuffer) incIndent() {
	b.indent++
}

// decIndent decreases the indentation level by one. It panics if the level is
// already zero, which indicates unbalanced incIndent/decIndent calls.
func (b *sourceBuffer) decIndent() {
	if b.indent <= 0 {
		panic("decIndent() without matching incIndent()")
	}
	b.indent--
}

// emit appends to the buffer at the current indentation level. It accepts the
// same argument forms as the package-level emit.
func (b *sourceBuffer) emit(a ...interface{}) {
	emit(&b.b, b.indent, a...)
}

// emitNoIndent is like emit but ignores the current indentation level.
func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
	emit(&b.b, 0 /*indent*/, a...)
}

// inIndent runs body with the indentation level increased by one, restoring
// the previous level afterwards.
func (b *sourceBuffer) inIndent(body func()) {
	b.incIndent()
	// Deferred so the level is restored even if body panics and the panic is
	// recovered further up the stack.
	defer b.decIndent()
	body()
}

// write copies the buffered contents to out. The buffer is not consumed.
func (b *sourceBuffer) write(out io.Writer) error {
	// Write the bytes directly; going through b.b.String() would first copy
	// the entire buffer into a new string.
	_, err := out.Write(b.b.Bytes())
	return err
}

// Write implements io.Writer.Write. buf is appended verbatim, without
// indentation.
func (b *sourceBuffer) Write(buf []byte) (int, error) {
	return b.b.Write(buf)
}

// importStmt represents a single import statement.
type importStmt struct {
	// Local name of the imported package.
	name string
	// Import path.
	path string
	// Indicates whether the local name is an alias, or simply the final
	// component of the path.
	aliased bool
	// Indicates whether this import was referenced by generated code.
	used bool
	// AST node and file set representing the import statement, if any. These
	// are only non-nil if the import statement originates from an input source
	// file.
	spec *ast.ImportSpec
	fset *token.FileSet
}

// newImport returns an unaliased import of path p whose local name is the
// final path component.
func newImport(p string) *importStmt {
	name := path.Base(p)
	return &importStmt{
		name:    name,
		path:    p,
		aliased: false,
	}
}

// newImportFromSpec returns the import described by spec, which comes from a
// source file in file set f. The local name is spec's explicit name if it has
// one, otherwise the final component of the import path. It panics if no
// usable name can be derived from the path.
func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
	name := path.Base(p)
	if name == "" || name == "/" || name == "." {
		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
			f.Position(spec.Path.Pos()), name))
	}
	if spec.Name != nil {
		name = spec.Name.Name
	}
	return &importStmt{
		name:    name,
		path:    p,
		aliased: spec.Name != nil,
		spec:    spec,
		fset:    f,
	}
}

// String implements fmt.Stringer.String. This generates a string for the import
// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
	if i.aliased {
		return fmt.Sprintf("%s %q", i.name, i.path)
	}
	return fmt.Sprintf("%q", i.path)
}

// debugString returns a debug string representing an import statement. This
// representation is not valid golang code and is used for debugging output.
func (i *importStmt) debugString() string {
	if i.spec != nil && i.fset != nil {
		return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
	}
	return fmt.Sprintf("(go-marshal import): %s", i)
}

// markUsed records that generated code references this import.
func (i *importStmt) markUsed() {
	i.used = true
}

// equivalent reports whether i and other produce the same import statement:
// same local name, same path, and same aliasing.
func (i *importStmt) equivalent(other *importStmt) bool {
	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}

// importTable represents a collection of importStmts.
//
// An importTable may contain multiple import statements referencing the same
// local name. All import statements aliasing to the same local name are
// technically ambiguous, as if such an import name is used in the generated
// code, it's not clear which import statement it refers to. We ignore any
// potential collisions until actually writing the import table to the generated
// source file. See importTable.write.
349 // 350 // Given the following import statements across all the files comprising a 351 // package marshalled: 352 // 353 // "sync" 354 // "pkg/sync" 355 // "pkg/sentry/kernel" 356 // ktime "pkg/sentry/kernel/time" 357 // 358 // An importTable representing them would look like this: 359 // 360 // importTable { 361 // is: map[string][]*importStmt { 362 // "sync": []*importStmt{ 363 // importStmt{name:"sync", path:"sync", aliased:false} 364 // importStmt{name:"sync", path:"pkg/sync", aliased:false} 365 // }, 366 // "kernel": []*importStmt{importStmt{ 367 // name: "kernel", 368 // path: "pkg/sentry/kernel", 369 // aliased: false 370 // }}, 371 // "ktime": []*importStmt{importStmt{ 372 // name: "ktime", 373 // path: "pkg/sentry/kernel/time", 374 // aliased: true, 375 // }}, 376 // } 377 // } 378 // 379 // Note that the local name "sync" is assigned to two different import 380 // statements. This is possible if the import statements are from different 381 // source files in the same package. 382 // 383 // Since go-marshal generates a single output file per package regardless of the 384 // number of input files, if "sync" is referenced by any generated code, it's 385 // unclear which import statement "sync" refers to. While it's theoretically 386 // possible to resolve this by assigning a unique local alias to each instance 387 // of the sync package, go-marshal currently aborts when it encounters such an 388 // ambiguity. 389 // 390 // TODO(b/151478251): importTable considers the final component of an import 391 // path to be the package name, but this is only a convention. The actual 392 // package name is determined by the package statement in the source files for 393 // the package. 394 type importTable struct { 395 // Map of imports and whether they should be copied to the output. 
396 is map[string][]*importStmt 397 } 398 399 func newImportTable() *importTable { 400 return &importTable{ 401 is: make(map[string][]*importStmt), 402 } 403 } 404 405 // Merges import statements from other into i. 406 func (i *importTable) merge(other *importTable) { 407 for name, ims := range other.is { 408 i.is[name] = append(i.is[name], ims...) 409 } 410 } 411 412 func (i *importTable) addStmt(s *importStmt) *importStmt { 413 i.is[s.name] = append(i.is[s.name], s) 414 return s 415 } 416 417 func (i *importTable) add(s string) *importStmt { 418 n := newImport(s) 419 return i.addStmt(n) 420 } 421 422 func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { 423 return i.addStmt(newImportFromSpec(spec, f)) 424 } 425 426 // Marks the import named n as used. If no such import is in the table, returns 427 // false. 428 func (i *importTable) markUsed(n string) bool { 429 if ns, ok := i.is[n]; ok { 430 for _, n := range ns { 431 n.markUsed() 432 } 433 return true 434 } 435 return false 436 } 437 438 func (i *importTable) clear() { 439 for _, is := range i.is { 440 for _, i := range is { 441 i.used = false 442 } 443 } 444 } 445 446 func (i *importTable) write(out io.Writer) error { 447 if len(i.is) == 0 { 448 // Nothing to import, we're done. 449 return nil 450 } 451 452 imports := make([]string, 0, len(i.is)) 453 for name, is := range i.is { 454 var lastUsed *importStmt 455 var ambiguous bool 456 457 for _, i := range is { 458 if i.used { 459 if lastUsed != nil { 460 if !i.equivalent(lastUsed) { 461 ambiguous = true 462 } 463 } 464 lastUsed = i 465 } 466 } 467 468 if ambiguous { 469 // We have two or more import statements across the different source 470 // files that share a local name, and at least one of these imports 471 // are used by the generated code. This ambiguity can't be resolved 472 // by go-marshal and requires the user intervention. 
Dump a list of 473 // the colliding import statements and let the user modify the input 474 // files as appropriate. 475 var b strings.Builder 476 fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) 477 fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) 478 // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't 479 // be true. Therefore the slicing below is safe. 480 for _, i := range is[:len(is)-1] { 481 fmt.Fprintf(&b, " %v\n", i.debugString()) 482 } 483 fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) 484 panic(b.String()) 485 } 486 487 if lastUsed != nil { 488 imports = append(imports, lastUsed.String()) 489 } 490 } 491 sort.Strings(imports) 492 493 var b sourceBuffer 494 b.emit("import (\n") 495 b.incIndent() 496 for _, i := range imports { 497 b.emit("%s\n", i) 498 } 499 b.decIndent() 500 b.emit(")\n\n") 501 502 return b.write(out) 503 }