github.com/gobwas/gtrace@v0.4.3/cmd/gtrace/main.go (about) 1 package main 2 3 import ( 4 "bufio" 5 "bytes" 6 "flag" 7 "fmt" 8 "go/ast" 9 "go/build" 10 "go/importer" 11 "go/parser" 12 "go/token" 13 "go/types" 14 "io" 15 "log" 16 "os" 17 "path/filepath" 18 "reflect" 19 "strings" 20 "text/tabwriter" 21 22 _ "unsafe" // For go:linkname. 23 ) 24 25 //go:linkname build_goodOSArchFile go/build.(*Context).goodOSArchFile 26 func build_goodOSArchFile(*build.Context, string, map[string]bool) bool 27 28 func main() { 29 var ( 30 verbose bool 31 suffix string 32 stubSuffix string 33 write bool 34 buildTag string 35 ) 36 flag.BoolVar(&verbose, 37 "v", false, 38 "output debug info", 39 ) 40 flag.BoolVar(&write, 41 "w", false, 42 "write trace to file", 43 ) 44 flag.StringVar(&suffix, 45 "file-suffix", "_gtrace", 46 "suffix for generated go files", 47 ) 48 flag.StringVar(&stubSuffix, 49 "stub-file-suffix", "_stub", 50 "suffix for generated stub go files", 51 ) 52 flag.StringVar(&buildTag, 53 "tag", "", 54 "build tag which needs to be passed to enable tracing", 55 ) 56 flag.Parse() 57 58 if verbose { 59 log.SetFlags(log.Lshortfile) 60 } else { 61 log.SetFlags(0) 62 } 63 64 var ( 65 // Reports whether we were called from go:generate. 66 isGoGenerate bool 67 68 gofile string 69 workDir string 70 err error 71 ) 72 if gofile = os.Getenv("GOFILE"); gofile != "" { 73 // NOTE: GOFILE is always a filename without path. 74 isGoGenerate = true 75 workDir, err = os.Getwd() 76 if err != nil { 77 log.Fatal(err) 78 } 79 } else { 80 args := flag.Args() 81 if len(args) == 0 { 82 log.Fatal("no $GOFILE env nor file parameter were given") 83 } 84 gofile = filepath.Base(args[0]) 85 workDir = filepath.Dir(args[0]) 86 } 87 { 88 prefix := filepath.Join(filepath.Base(workDir), gofile) 89 log.SetPrefix("[" + prefix + "] ") 90 } 91 buildCtx := build.Default 92 if verbose { 93 var sb strings.Builder 94 prettyPrint(&sb, buildCtx) 95 log.Printf("build context:\n%s", sb.String()) 96 } 97 buildPkg, err := buildCtx.ImportDir(workDir, build.IgnoreVendor) 98 if err != nil { 99 log.Fatal(err) 100 } 101 102 srcFilePath := filepath.Join(workDir, gofile) 103 if verbose { 104 log.Printf("source file: %s", srcFilePath) 105 log.Printf("package files: %v", buildPkg.GoFiles) 106 } 107 108 var writers []*Writer 109 if isGoGenerate || write { 110 // We should respect Go suffixes like `_linux.go`. 111 name, tags, ext := splitOSArchTags(&buildCtx, gofile) 112 if verbose { 113 log.Printf( 114 "split os/args tags of %q: %q %q %q", 115 gofile, name, tags, ext, 116 ) 117 } 118 openFile := func(name string) (*os.File, func()) { 119 p := filepath.Join(workDir, name) 120 if verbose { 121 log.Printf("destination file path: %+v", p) 122 } 123 f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) 124 if err != nil { 125 log.Fatal(err) 126 } 127 return f, func() { f.Close() } 128 } 129 f, clean := openFile(name + suffix + tags + ext) 130 defer clean() 131 writers = append(writers, &Writer{ 132 Context: buildCtx, 133 Output: f, 134 BuildTag: buildTag, 135 }) 136 if buildTag != "" { 137 f, clean := openFile(name + suffix + stubSuffix + tags + ext) 138 defer clean() 139 writers = append(writers, &Writer{ 140 Context: buildCtx, 141 Output: f, 142 BuildTag: buildTag, 143 Stub: true, 144 }) 145 } 146 } else { 147 writers = append(writers, &Writer{ 148 Context: buildCtx, 149 Output: os.Stdout, 150 BuildTag: buildTag, 151 Stub: true, 152 }) 153 } 154 155 var ( 156 pkgFiles = make([]*os.File, 0, len(buildPkg.GoFiles)) 157 astFiles = make([]*ast.File, 0, len(buildPkg.GoFiles)) 158 159 buildConstraints []string 160 ) 161 fset := token.NewFileSet() 162 for _, name := range buildPkg.GoFiles { 163 base, _, _ := splitOSArchTags(&buildCtx, name) 164 if isGenerated(base, suffix) { 165 // Skip gtrace generated files. 166 if verbose { 167 log.Printf("skipped package file: %q", name) 168 } 169 continue 170 } 171 if verbose { 172 log.Printf("parsing package file: %q", name) 173 } 174 file, err := os.Open(filepath.Join(workDir, name)) 175 if err != nil { 176 log.Fatal(err) 177 } 178 defer file.Close() 179 180 ast, err := parser.ParseFile(fset, file.Name(), file, parser.ParseComments) 181 if err != nil { 182 log.Fatalf("parse %q error: %v", file.Name(), err) 183 } 184 185 pkgFiles = append(pkgFiles, file) 186 astFiles = append(astFiles, ast) 187 188 if name == gofile { 189 if _, err := file.Seek(0, io.SeekStart); err != nil { 190 log.Fatal(err) 191 } 192 buildConstraints, err = scanBuildConstraints(file) 193 if err != nil { 194 log.Fatal(err) 195 } 196 } 197 } 198 info := types.Info{ 199 Types: make(map[ast.Expr]types.TypeAndValue), 200 Defs: make(map[*ast.Ident]types.Object), 201 Uses: make(map[*ast.Ident]types.Object), 202 } 203 conf := types.Config{ 204 IgnoreFuncBodies: true, 205 DisableUnusedImportCheck: true, 206 Importer: importer.ForCompiler(fset, "source", nil), 207 } 208 pkg, err := conf.Check(".", fset, astFiles, &info) 209 if err != nil { 210 log.Fatalf("type error: %v", err) 211 } 212 var items []*GenItem 213 for i, astFile := range astFiles { 214 if pkgFiles[i].Name() != srcFilePath { 215 continue 216 } 217 var ( 218 depth int 219 item *GenItem 220 ) 221 logf := func(s string, args ...interface{}) { 222 if !verbose { 223 return 224 } 225 log.Print( 226 strings.Repeat(" ", depth*4), 227 fmt.Sprintf(s, args...), 228 ) 229 } 230 ast.Inspect(astFile, func(n ast.Node) (next bool) { 231 logf("%T", n) 232 233 if n == nil { 234 item = nil 235 depth-- 236 return true 237 } 238 defer func() { 239 if next { 240 depth++ 241 } 242 }() 243 244 switch v := n.(type) { 245 case 246 *ast.FuncDecl, 247 *ast.ValueSpec: 248 return false 249 250 case *ast.Ident: 251 logf("ident %q", v.Name) 252 if item != nil { 253 item.Ident = v 254 } 255 return false 256 257 case *ast.CommentGroup: 258 for i, c := range v.List { 259 logf("#%d comment %q", i, c.Text) 260 261 text, ok := TrimConfigComment(c.Text) 262 if ok { 263 if item == nil { 264 item = &GenItem{} 265 } 266 if err := item.ParseComment(text); err != nil { 267 log.Fatalf( 268 "malformed comment string: %q: %v", 269 text, err, 270 ) 271 } 272 } 273 } 274 return false 275 276 case *ast.StructType: 277 logf("struct %+v", v) 278 if item != nil { 279 item.StructType = v 280 items = append(items, item) 281 item = nil 282 } 283 return false 284 } 285 286 return true 287 }) 288 } 289 p := Package{ 290 Package: pkg, 291 BuildConstraints: buildConstraints, 292 } 293 traces := make(map[string]*Trace) 294 for _, item := range items { 295 t := &Trace{ 296 Name: item.Ident.Name, 297 Flag: item.Flag, 298 } 299 p.Traces = append(p.Traces, t) 300 traces[item.Ident.Name] = t 301 } 302 for i, item := range items { 303 t := p.Traces[i] 304 for _, field := range item.StructType.Fields.List { 305 name := field.Names[0].Name 306 fn, ok := field.Type.(*ast.FuncType) 307 if !ok { 308 continue 309 } 310 f, err := buildFunc(info, traces, fn) 311 if err != nil { 312 log.Printf( 313 "skipping hook %s due to error: %v", 314 name, err, 315 ) 316 continue 317 } 318 var config GenConfig 319 if doc := field.Doc; doc != nil { 320 for _, line := range doc.List { 321 text, ok := TrimConfigComment(line.Text) 322 if !ok { 323 continue 324 } 325 err := config.ParseComment(text) 326 if err != nil { 327 log.Fatalf( 328 "malformed comment string: %q: %v", 329 text, err, 330 ) 331 } 332 } 333 } 334 t.Hooks = append(t.Hooks, Hook{ 335 Name: name, 336 Func: f, 337 Flag: item.GenConfig.Flag | config.Flag, 338 }) 339 } 340 } 341 for _, w := range writers { 342 if err := w.Write(p); err != nil { 343 log.Fatal(err) 344 } 345 } 346 347 log.Println("OK") 348 } 349 350 func buildFunc(info types.Info, traces map[string]*Trace, fn *ast.FuncType) (ret *Func, err error) { 351 ret = new(Func) 352 for _, p := range fn.Params.List { 353 t := info.TypeOf(p.Type) 354 if t == nil { 355 log.Fatalf("unknown type: %s", p.Type) 356 } 357 var names []string 358 for _, n := range p.Names { 359 name := n.Name 360 if name == "_" { 361 name = "" 362 } 363 names = append(names, name) 364 } 365 if len(names) == 0 { 366 // Case where arg is not named. 367 names = []string{""} 368 } 369 for _, name := range names { 370 ret.Params = append(ret.Params, Param{ 371 Name: name, 372 Type: t, 373 }) 374 } 375 } 376 if fn.Results == nil { 377 return ret, nil 378 } 379 if len(fn.Results.List) > 1 { 380 return nil, fmt.Errorf( 381 "unsupported number of function results", 382 ) 383 } 384 385 r := fn.Results.List[0] 386 387 switch x := r.Type.(type) { 388 case *ast.FuncType: 389 result, err := buildFunc(info, traces, x) 390 if err != nil { 391 return nil, err 392 } 393 ret.Result = append(ret.Result, result) 394 return ret, nil 395 396 case *ast.Ident: 397 if t, ok := traces[x.Name]; ok { 398 t.Nested = true 399 ret.Result = append(ret.Result, t) 400 return ret, nil 401 } 402 } 403 404 return nil, fmt.Errorf( 405 "unsupported function result type %s", 406 info.TypeOf(r.Type), 407 ) 408 } 409 410 func splitOSArchTags(ctx *build.Context, name string) (base, tags, ext string) { 411 fileTags := make(map[string]bool) 412 build_goodOSArchFile(ctx, name, fileTags) 413 ext = filepath.Ext(name) 414 switch len(fileTags) { 415 case 0: // * 416 base = strings.TrimSuffix(name, ext) 417 418 case 1: // *_GOOS or *_GOARCH 419 i := strings.LastIndexByte(name, '_') 420 421 base = name[:i] 422 tags = strings.TrimSuffix(name[i:], ext) 423 424 case 2: // *_GOOS_GOARCH 425 var i int 426 i = strings.LastIndexByte(name, '_') 427 i = strings.LastIndexByte(name[:i], '_') 428 429 base = name[:i] 430 tags = strings.TrimSuffix(name[i:], ext) 431 432 default: 433 panic(fmt.Sprintf( 434 "gtrace: internal error: unexpected number of OS/arch tags: %d", 435 len(fileTags), 436 )) 437 } 438 return 439 } 440 441 type Package struct { 442 *types.Package 443 444 BuildConstraints []string 445 Traces []*Trace 446 } 447 448 type Trace struct { 449 Name string 450 Hooks []Hook 451 Flag GenFlag 452 Nested bool 453 } 454 455 func (*Trace) isFuncResult() bool { return true } 456 457 type Hook struct { 458 Name string 459 Func *Func 460 Flag GenFlag 461 } 462 463 type Param struct { 464 Name string // Might be empty. 465 Type types.Type 466 } 467 468 type FuncResult interface { 469 isFuncResult() bool 470 } 471 472 type Func struct { 473 Params []Param 474 Result []FuncResult // 0 or 1. 475 } 476 477 func (*Func) isFuncResult() bool { return true } 478 479 func (f *Func) HasResult() bool { 480 return len(f.Result) > 0 481 } 482 483 type GenFlag uint8 484 485 func (f GenFlag) Has(x GenFlag) bool { 486 return f&x != 0 487 } 488 489 const ( 490 GenZero GenFlag = 1 << iota >> 1 491 GenShortcut 492 GenContext 493 494 GenAll = ^GenFlag(0) 495 ) 496 497 type GenConfig struct { 498 Flag GenFlag 499 } 500 501 func TrimConfigComment(text string) (string, bool) { 502 s := strings.TrimPrefix(text, "//gtrace:") 503 if text != s { 504 return s, true 505 } 506 return "", false 507 } 508 509 func (g *GenConfig) ParseComment(text string) (err error) { 510 prefix, text := split(text, ' ') 511 switch prefix { 512 case "gen": 513 case "set": 514 return g.ParseParameter(text) 515 default: 516 return fmt.Errorf("unknown prefix: %q", prefix) 517 } 518 return nil 519 } 520 521 func (g *GenConfig) ParseParameter(text string) (err error) { 522 text = strings.TrimSpace(text) 523 param, _ := split(text, '=') 524 if param == "" { 525 return nil 526 } 527 switch param { 528 case "shortcut": 529 g.Flag |= GenShortcut 530 case "context": 531 g.Flag |= GenContext 532 default: 533 return fmt.Errorf("unexpected parameter: %q", param) 534 } 535 return nil 536 } 537 538 type GenItem struct { 539 GenConfig 540 Ident *ast.Ident 541 StructType *ast.StructType 542 } 543 544 func split(s string, c byte) (s1, s2 string) { 545 i := strings.IndexByte(s, c) 546 if i == -1 { 547 return s, "" 548 } 549 return s[:i], s[i+1:] 550 } 551 552 func rsplit(s string, c byte) (s1, s2 string) { 553 i := strings.LastIndexByte(s, c) 554 if i == -1 { 555 return s, "" 556 } 557 return s[:i], s[i+1:] 558 } 559 560 func scanBuildConstraints(r io.Reader) (cs []string, err error) { 561 br := bufio.NewReader(r) 562 for { 563 line, err := br.ReadBytes('\n') 564 if err != nil { 565 return nil, err 566 } 567 line = bytes.TrimSpace(line) 568 if comm := bytes.TrimPrefix(line, []byte("//")); !bytes.Equal(comm, line) { 569 comm = bytes.TrimSpace(comm) 570 if bytes.HasPrefix(comm, []byte("+build")) { 571 cs = append(cs, string(line)) 572 continue 573 } 574 } 575 if bytes.HasPrefix(line, []byte("package ")) { 576 break 577 } 578 } 579 return cs, nil 580 } 581 582 func prettyPrint(w io.Writer, x interface{}) { 583 tw := tabwriter.NewWriter(w, 0, 2, 2, ' ', 0) 584 t := reflect.TypeOf(x) 585 v := reflect.ValueOf(x) 586 for i := 0; i < t.NumField(); i++ { 587 if v.Field(i).IsZero() { 588 continue 589 } 590 fmt.Fprintf(tw, "%s:\t%v\n", 591 t.Field(i).Name, 592 v.Field(i), 593 ) 594 } 595 tw.Flush() 596 } 597 598 func isGenerated(base, suffix string) bool { 599 i := strings.Index(base, suffix) 600 if i == -1 { 601 return false 602 } 603 n := len(base) 604 m := i + len(suffix) 605 return m == n || base[m] == '_' 606 }