github.com/gobwas/gtrace@v0.4.3/cmd/gtrace/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"flag"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/build"
    10  	"go/importer"
    11  	"go/parser"
    12  	"go/token"
    13  	"go/types"
    14  	"io"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    18  	"reflect"
    19  	"strings"
    20  	"text/tabwriter"
    21  
    22  	_ "unsafe" // For go:linkname.
    23  )
    24  
// build_goodOSArchFile is go/build's unexported (*Context).goodOSArchFile,
// pulled in via go:linkname so gtrace can recognize the same _GOOS/_GOARCH
// file-name suffixes the go command does. splitOSArchTags relies on it
// recording each OS/arch tag found in the file name into the map argument.
//
// NOTE(review): this references a private go/build symbol. Newer Go
// toolchains restrict go:linkname references to unexported standard-library
// symbols, so this may fail to link there — confirm with the toolchain in use.
//go:linkname build_goodOSArchFile go/build.(*Context).goodOSArchFile
func build_goodOSArchFile(*build.Context, string, map[string]bool) bool
    27  
    28  func main() {
    29  	var (
    30  		verbose    bool
    31  		suffix     string
    32  		stubSuffix string
    33  		write      bool
    34  		buildTag   string
    35  	)
    36  	flag.BoolVar(&verbose,
    37  		"v", false,
    38  		"output debug info",
    39  	)
    40  	flag.BoolVar(&write,
    41  		"w", false,
    42  		"write trace to file",
    43  	)
    44  	flag.StringVar(&suffix,
    45  		"file-suffix", "_gtrace",
    46  		"suffix for generated go files",
    47  	)
    48  	flag.StringVar(&stubSuffix,
    49  		"stub-file-suffix", "_stub",
    50  		"suffix for generated stub go files",
    51  	)
    52  	flag.StringVar(&buildTag,
    53  		"tag", "",
    54  		"build tag which needs to be passed to enable tracing",
    55  	)
    56  	flag.Parse()
    57  
    58  	if verbose {
    59  		log.SetFlags(log.Lshortfile)
    60  	} else {
    61  		log.SetFlags(0)
    62  	}
    63  
    64  	var (
    65  		// Reports whether we were called from go:generate.
    66  		isGoGenerate bool
    67  
    68  		gofile  string
    69  		workDir string
    70  		err     error
    71  	)
    72  	if gofile = os.Getenv("GOFILE"); gofile != "" {
    73  		// NOTE: GOFILE is always a filename without path.
    74  		isGoGenerate = true
    75  		workDir, err = os.Getwd()
    76  		if err != nil {
    77  			log.Fatal(err)
    78  		}
    79  	} else {
    80  		args := flag.Args()
    81  		if len(args) == 0 {
    82  			log.Fatal("no $GOFILE env nor file parameter were given")
    83  		}
    84  		gofile = filepath.Base(args[0])
    85  		workDir = filepath.Dir(args[0])
    86  	}
    87  	{
    88  		prefix := filepath.Join(filepath.Base(workDir), gofile)
    89  		log.SetPrefix("[" + prefix + "] ")
    90  	}
    91  	buildCtx := build.Default
    92  	if verbose {
    93  		var sb strings.Builder
    94  		prettyPrint(&sb, buildCtx)
    95  		log.Printf("build context:\n%s", sb.String())
    96  	}
    97  	buildPkg, err := buildCtx.ImportDir(workDir, build.IgnoreVendor)
    98  	if err != nil {
    99  		log.Fatal(err)
   100  	}
   101  
   102  	srcFilePath := filepath.Join(workDir, gofile)
   103  	if verbose {
   104  		log.Printf("source file: %s", srcFilePath)
   105  		log.Printf("package files: %v", buildPkg.GoFiles)
   106  	}
   107  
   108  	var writers []*Writer
   109  	if isGoGenerate || write {
   110  		// We should respect Go suffixes like `_linux.go`.
   111  		name, tags, ext := splitOSArchTags(&buildCtx, gofile)
   112  		if verbose {
   113  			log.Printf(
   114  				"split os/args tags of %q: %q %q %q",
   115  				gofile, name, tags, ext,
   116  			)
   117  		}
   118  		openFile := func(name string) (*os.File, func()) {
   119  			p := filepath.Join(workDir, name)
   120  			if verbose {
   121  				log.Printf("destination file path: %+v", p)
   122  			}
   123  			f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
   124  			if err != nil {
   125  				log.Fatal(err)
   126  			}
   127  			return f, func() { f.Close() }
   128  		}
   129  		f, clean := openFile(name + suffix + tags + ext)
   130  		defer clean()
   131  		writers = append(writers, &Writer{
   132  			Context:  buildCtx,
   133  			Output:   f,
   134  			BuildTag: buildTag,
   135  		})
   136  		if buildTag != "" {
   137  			f, clean := openFile(name + suffix + stubSuffix + tags + ext)
   138  			defer clean()
   139  			writers = append(writers, &Writer{
   140  				Context:  buildCtx,
   141  				Output:   f,
   142  				BuildTag: buildTag,
   143  				Stub:     true,
   144  			})
   145  		}
   146  	} else {
   147  		writers = append(writers, &Writer{
   148  			Context:  buildCtx,
   149  			Output:   os.Stdout,
   150  			BuildTag: buildTag,
   151  			Stub:     true,
   152  		})
   153  	}
   154  
   155  	var (
   156  		pkgFiles = make([]*os.File, 0, len(buildPkg.GoFiles))
   157  		astFiles = make([]*ast.File, 0, len(buildPkg.GoFiles))
   158  
   159  		buildConstraints []string
   160  	)
   161  	fset := token.NewFileSet()
   162  	for _, name := range buildPkg.GoFiles {
   163  		base, _, _ := splitOSArchTags(&buildCtx, name)
   164  		if isGenerated(base, suffix) {
   165  			// Skip gtrace generated files.
   166  			if verbose {
   167  				log.Printf("skipped package file: %q", name)
   168  			}
   169  			continue
   170  		}
   171  		if verbose {
   172  			log.Printf("parsing package file: %q", name)
   173  		}
   174  		file, err := os.Open(filepath.Join(workDir, name))
   175  		if err != nil {
   176  			log.Fatal(err)
   177  		}
   178  		defer file.Close()
   179  
   180  		ast, err := parser.ParseFile(fset, file.Name(), file, parser.ParseComments)
   181  		if err != nil {
   182  			log.Fatalf("parse %q error: %v", file.Name(), err)
   183  		}
   184  
   185  		pkgFiles = append(pkgFiles, file)
   186  		astFiles = append(astFiles, ast)
   187  
   188  		if name == gofile {
   189  			if _, err := file.Seek(0, io.SeekStart); err != nil {
   190  				log.Fatal(err)
   191  			}
   192  			buildConstraints, err = scanBuildConstraints(file)
   193  			if err != nil {
   194  				log.Fatal(err)
   195  			}
   196  		}
   197  	}
   198  	info := types.Info{
   199  		Types: make(map[ast.Expr]types.TypeAndValue),
   200  		Defs:  make(map[*ast.Ident]types.Object),
   201  		Uses:  make(map[*ast.Ident]types.Object),
   202  	}
   203  	conf := types.Config{
   204  		IgnoreFuncBodies:         true,
   205  		DisableUnusedImportCheck: true,
   206  		Importer:                 importer.ForCompiler(fset, "source", nil),
   207  	}
   208  	pkg, err := conf.Check(".", fset, astFiles, &info)
   209  	if err != nil {
   210  		log.Fatalf("type error: %v", err)
   211  	}
   212  	var items []*GenItem
   213  	for i, astFile := range astFiles {
   214  		if pkgFiles[i].Name() != srcFilePath {
   215  			continue
   216  		}
   217  		var (
   218  			depth int
   219  			item  *GenItem
   220  		)
   221  		logf := func(s string, args ...interface{}) {
   222  			if !verbose {
   223  				return
   224  			}
   225  			log.Print(
   226  				strings.Repeat(" ", depth*4),
   227  				fmt.Sprintf(s, args...),
   228  			)
   229  		}
   230  		ast.Inspect(astFile, func(n ast.Node) (next bool) {
   231  			logf("%T", n)
   232  
   233  			if n == nil {
   234  				item = nil
   235  				depth--
   236  				return true
   237  			}
   238  			defer func() {
   239  				if next {
   240  					depth++
   241  				}
   242  			}()
   243  
   244  			switch v := n.(type) {
   245  			case
   246  				*ast.FuncDecl,
   247  				*ast.ValueSpec:
   248  				return false
   249  
   250  			case *ast.Ident:
   251  				logf("ident %q", v.Name)
   252  				if item != nil {
   253  					item.Ident = v
   254  				}
   255  				return false
   256  
   257  			case *ast.CommentGroup:
   258  				for i, c := range v.List {
   259  					logf("#%d comment %q", i, c.Text)
   260  
   261  					text, ok := TrimConfigComment(c.Text)
   262  					if ok {
   263  						if item == nil {
   264  							item = &GenItem{}
   265  						}
   266  						if err := item.ParseComment(text); err != nil {
   267  							log.Fatalf(
   268  								"malformed comment string: %q: %v",
   269  								text, err,
   270  							)
   271  						}
   272  					}
   273  				}
   274  				return false
   275  
   276  			case *ast.StructType:
   277  				logf("struct %+v", v)
   278  				if item != nil {
   279  					item.StructType = v
   280  					items = append(items, item)
   281  					item = nil
   282  				}
   283  				return false
   284  			}
   285  
   286  			return true
   287  		})
   288  	}
   289  	p := Package{
   290  		Package:          pkg,
   291  		BuildConstraints: buildConstraints,
   292  	}
   293  	traces := make(map[string]*Trace)
   294  	for _, item := range items {
   295  		t := &Trace{
   296  			Name: item.Ident.Name,
   297  			Flag: item.Flag,
   298  		}
   299  		p.Traces = append(p.Traces, t)
   300  		traces[item.Ident.Name] = t
   301  	}
   302  	for i, item := range items {
   303  		t := p.Traces[i]
   304  		for _, field := range item.StructType.Fields.List {
   305  			name := field.Names[0].Name
   306  			fn, ok := field.Type.(*ast.FuncType)
   307  			if !ok {
   308  				continue
   309  			}
   310  			f, err := buildFunc(info, traces, fn)
   311  			if err != nil {
   312  				log.Printf(
   313  					"skipping hook %s due to error: %v",
   314  					name, err,
   315  				)
   316  				continue
   317  			}
   318  			var config GenConfig
   319  			if doc := field.Doc; doc != nil {
   320  				for _, line := range doc.List {
   321  					text, ok := TrimConfigComment(line.Text)
   322  					if !ok {
   323  						continue
   324  					}
   325  					err := config.ParseComment(text)
   326  					if err != nil {
   327  						log.Fatalf(
   328  							"malformed comment string: %q: %v",
   329  							text, err,
   330  						)
   331  					}
   332  				}
   333  			}
   334  			t.Hooks = append(t.Hooks, Hook{
   335  				Name: name,
   336  				Func: f,
   337  				Flag: item.GenConfig.Flag | config.Flag,
   338  			})
   339  		}
   340  	}
   341  	for _, w := range writers {
   342  		if err := w.Write(p); err != nil {
   343  			log.Fatal(err)
   344  		}
   345  	}
   346  
   347  	log.Println("OK")
   348  }
   349  
   350  func buildFunc(info types.Info, traces map[string]*Trace, fn *ast.FuncType) (ret *Func, err error) {
   351  	ret = new(Func)
   352  	for _, p := range fn.Params.List {
   353  		t := info.TypeOf(p.Type)
   354  		if t == nil {
   355  			log.Fatalf("unknown type: %s", p.Type)
   356  		}
   357  		var names []string
   358  		for _, n := range p.Names {
   359  			name := n.Name
   360  			if name == "_" {
   361  				name = ""
   362  			}
   363  			names = append(names, name)
   364  		}
   365  		if len(names) == 0 {
   366  			// Case where arg is not named.
   367  			names = []string{""}
   368  		}
   369  		for _, name := range names {
   370  			ret.Params = append(ret.Params, Param{
   371  				Name: name,
   372  				Type: t,
   373  			})
   374  		}
   375  	}
   376  	if fn.Results == nil {
   377  		return ret, nil
   378  	}
   379  	if len(fn.Results.List) > 1 {
   380  		return nil, fmt.Errorf(
   381  			"unsupported number of function results",
   382  		)
   383  	}
   384  
   385  	r := fn.Results.List[0]
   386  
   387  	switch x := r.Type.(type) {
   388  	case *ast.FuncType:
   389  		result, err := buildFunc(info, traces, x)
   390  		if err != nil {
   391  			return nil, err
   392  		}
   393  		ret.Result = append(ret.Result, result)
   394  		return ret, nil
   395  
   396  	case *ast.Ident:
   397  		if t, ok := traces[x.Name]; ok {
   398  			t.Nested = true
   399  			ret.Result = append(ret.Result, t)
   400  			return ret, nil
   401  		}
   402  	}
   403  
   404  	return nil, fmt.Errorf(
   405  		"unsupported function result type %s",
   406  		info.TypeOf(r.Type),
   407  	)
   408  }
   409  
   410  func splitOSArchTags(ctx *build.Context, name string) (base, tags, ext string) {
   411  	fileTags := make(map[string]bool)
   412  	build_goodOSArchFile(ctx, name, fileTags)
   413  	ext = filepath.Ext(name)
   414  	switch len(fileTags) {
   415  	case 0: // *
   416  		base = strings.TrimSuffix(name, ext)
   417  
   418  	case 1: // *_GOOS or *_GOARCH
   419  		i := strings.LastIndexByte(name, '_')
   420  
   421  		base = name[:i]
   422  		tags = strings.TrimSuffix(name[i:], ext)
   423  
   424  	case 2: // *_GOOS_GOARCH
   425  		var i int
   426  		i = strings.LastIndexByte(name, '_')
   427  		i = strings.LastIndexByte(name[:i], '_')
   428  
   429  		base = name[:i]
   430  		tags = strings.TrimSuffix(name[i:], ext)
   431  
   432  	default:
   433  		panic(fmt.Sprintf(
   434  			"gtrace: internal error: unexpected number of OS/arch tags: %d",
   435  			len(fileTags),
   436  		))
   437  	}
   438  	return
   439  }
   440  
// Package is the input handed to each Writer: the type-checked package plus
// what gtrace extracted from the source file.
type Package struct {
	// Package is the result of type-checking the package's non-generated files.
	*types.Package

	// BuildConstraints are the "+build" comment lines from the source file's
	// header, copied verbatim (trimmed of surrounding whitespace).
	BuildConstraints []string
	// Traces lists the annotated trace structs found in the source file, in
	// source order.
	Traces           []*Trace
}
   447  
// Trace describes one struct type in the source file that is annotated with
// //gtrace: directives.
type Trace struct {
	Name   string  // name of the struct type
	Hooks  []Hook  // hooks built from the struct's func-typed fields
	Flag   GenFlag // options from the struct's own //gtrace:set directives
	Nested bool    // set when some hook returns this trace type
}

// isFuncResult marks *Trace as a FuncResult, allowing a hook to return a
// nested trace struct.
func (*Trace) isFuncResult() bool { return true }
   456  
// Hook describes one func-typed field of a trace struct.
type Hook struct {
	Name string  // field name
	Func *Func   // the field's function signature
	Flag GenFlag // trace-level flags OR'ed with the field's own //gtrace:set flags
}

// Param is one parameter of a hook function. Name is empty for unnamed and
// blank ("_") parameters.
type Param struct {
	Name string // Might be empty.
	Type types.Type
}

// FuncResult is the single result of a hook function. It is either a *Func,
// when the hook returns another function, or a *Trace, when it returns a
// nested trace struct.
type FuncResult interface {
	isFuncResult() bool
}
   471  
// Func describes the signature of a hook function.
type Func struct {
	// Params has one entry per declared parameter name, or a single entry
	// for an unnamed parameter.
	Params []Param
	Result []FuncResult // 0 or 1.
}

// isFuncResult marks *Func as a FuncResult: a hook may return another
// function.
func (*Func) isFuncResult() bool { return true }

// HasResult reports whether the function declares a result.
func (f *Func) HasResult() bool {
	return len(f.Result) > 0
}
   482  
   483  type GenFlag uint8
   484  
   485  func (f GenFlag) Has(x GenFlag) bool {
   486  	return f&x != 0
   487  }
   488  
   489  const (
   490  	GenZero GenFlag = 1 << iota >> 1
   491  	GenShortcut
   492  	GenContext
   493  
   494  	GenAll = ^GenFlag(0)
   495  )
   496  
   497  type GenConfig struct {
   498  	Flag GenFlag
   499  }
   500  
   501  func TrimConfigComment(text string) (string, bool) {
   502  	s := strings.TrimPrefix(text, "//gtrace:")
   503  	if text != s {
   504  		return s, true
   505  	}
   506  	return "", false
   507  }
   508  
   509  func (g *GenConfig) ParseComment(text string) (err error) {
   510  	prefix, text := split(text, ' ')
   511  	switch prefix {
   512  	case "gen":
   513  	case "set":
   514  		return g.ParseParameter(text)
   515  	default:
   516  		return fmt.Errorf("unknown prefix: %q", prefix)
   517  	}
   518  	return nil
   519  }
   520  
   521  func (g *GenConfig) ParseParameter(text string) (err error) {
   522  	text = strings.TrimSpace(text)
   523  	param, _ := split(text, '=')
   524  	if param == "" {
   525  		return nil
   526  	}
   527  	switch param {
   528  	case "shortcut":
   529  		g.Flag |= GenShortcut
   530  	case "context":
   531  		g.Flag |= GenContext
   532  	default:
   533  		return fmt.Errorf("unexpected parameter: %q", param)
   534  	}
   535  	return nil
   536  }
   537  
// GenItem collects what main's AST walk finds for one annotated struct
// declaration. It holds three things:
//   - the options parsed from the //gtrace: directives of the preceding
//     comment group (the embedded GenConfig);
//   - the identifier seen after that comment group;
//   - the struct type that completes the item.
type GenItem struct {
	GenConfig
	Ident      *ast.Ident
	StructType *ast.StructType
}
   543  
   544  func split(s string, c byte) (s1, s2 string) {
   545  	i := strings.IndexByte(s, c)
   546  	if i == -1 {
   547  		return s, ""
   548  	}
   549  	return s[:i], s[i+1:]
   550  }
   551  
   552  func rsplit(s string, c byte) (s1, s2 string) {
   553  	i := strings.LastIndexByte(s, c)
   554  	if i == -1 {
   555  		return s, ""
   556  	}
   557  	return s[:i], s[i+1:]
   558  }
   559  
   560  func scanBuildConstraints(r io.Reader) (cs []string, err error) {
   561  	br := bufio.NewReader(r)
   562  	for {
   563  		line, err := br.ReadBytes('\n')
   564  		if err != nil {
   565  			return nil, err
   566  		}
   567  		line = bytes.TrimSpace(line)
   568  		if comm := bytes.TrimPrefix(line, []byte("//")); !bytes.Equal(comm, line) {
   569  			comm = bytes.TrimSpace(comm)
   570  			if bytes.HasPrefix(comm, []byte("+build")) {
   571  				cs = append(cs, string(line))
   572  				continue
   573  			}
   574  		}
   575  		if bytes.HasPrefix(line, []byte("package ")) {
   576  			break
   577  		}
   578  	}
   579  	return cs, nil
   580  }
   581  
   582  func prettyPrint(w io.Writer, x interface{}) {
   583  	tw := tabwriter.NewWriter(w, 0, 2, 2, ' ', 0)
   584  	t := reflect.TypeOf(x)
   585  	v := reflect.ValueOf(x)
   586  	for i := 0; i < t.NumField(); i++ {
   587  		if v.Field(i).IsZero() {
   588  			continue
   589  		}
   590  		fmt.Fprintf(tw, "%s:\t%v\n",
   591  			t.Field(i).Name,
   592  			v.Field(i),
   593  		)
   594  	}
   595  	tw.Flush()
   596  }
   597  
   598  func isGenerated(base, suffix string) bool {
   599  	i := strings.Index(base, suffix)
   600  	if i == -1 {
   601  		return false
   602  	}
   603  	n := len(base)
   604  	m := i + len(suffix)
   605  	return m == n || base[m] == '_'
   606  }