github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/instrument/backup/main.go (about)

     1  // +build !instrument
     2  
     3  package main
     4  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/printer"
	"go/token"
	"io/ioutil"
	"os"
	"sort"
	"strings"
	"text/template"
)
    17  
    18  func main() {
    19  	processDir(".", func(f *ast.FuncDecl) bool { return true })
    20  }
    21  
    22  func processDir(path string, filter func(*ast.FuncDecl) bool) {
    23  	fs := token.NewFileSet()
    24  	pkgs, err := parser.ParseDir(fs, path, nil, parser.ParseComments|parser.DeclarationErrors)
    25  	if err != nil {
    26  		fmt.Fprintf(os.Stderr, "could not parse package: %v\n", err)
    27  		os.Exit(2)
    28  	}
    29  
    30  	if len(pkgs) > 2 {
    31  		fmt.Fprintln(os.Stderr, "found multiple packages")
    32  		os.Exit(2)
    33  	}
    34  
    35  	if len(pkgs) == 0 {
    36  		os.Exit(0)
    37  	}
    38  
    39  	type tmplEntry struct {
    40  		Fname string
    41  		Flag  string
    42  		Args  []string
    43  	}
    44  
    45  	var entries []tmplEntry
    46  	var pkgname string
    47  	var pkg *ast.Package
    48  
    49  	for name, p := range pkgs {
    50  		pkgname = name
    51  		pkg = p
    52  	}
    53  
    54  	for fname, file := range pkg.Files {
    55  		_ = fname
    56  		for _, fnctmp := range file.Decls {
    57  			fnc, ok := fnctmp.(*ast.FuncDecl)
    58  			if !ok || !filter(fnc) {
    59  				continue
    60  			}
    61  			entry := tmplEntry{
    62  				Fname: fnc.Name.String(),
    63  				Flag:  "__instrument_" + fnc.Name.String(),
    64  			}
    65  			for _, arg := range fnc.Type.Params.List {
    66  				for _, name := range arg.Names {
    67  					entry.Args = append(entry.Args, name.Name)
    68  				}
    69  			}
    70  			entries = append(entries, entry)
    71  			var buf bytes.Buffer
    72  			err = shimTmpl.Execute(&buf, entry)
    73  			if err != nil {
    74  				fmt.Fprintf(os.Stderr, "unexpected internal error: %v\n", err)
    75  				os.Exit(3)
    76  			}
    77  			stmt := parseStmt(string(buf.Bytes()))
    78  			if len(stmt.List) != 1 {
    79  				panic("internal error")
    80  			}
    81  			fnc.Body.List = append([]ast.Stmt{stmt.List[0]}, fnc.Body.List...)
    82  		}
    83  
    84  		origHasBuildTag := false
    85  
    86  		for _, c := range file.Comments {
    87  			for _, c := range c.List {
    88  				// fmt.Println(c.Text)
    89  				if c.Text == "// +build !instrument" {
    90  					c.Text = "// +build instrument"
    91  					origHasBuildTag = true
    92  				}
    93  			}
    94  		}
    95  
    96  		var buf bytes.Buffer
    97  		if origHasBuildTag {
    98  			printer.Fprint(&buf, fs, file)
    99  		} else {
   100  			buf.Write([]byte("// +build instrument\n\n"))
   101  			printer.Fprint(&buf, fs, file)
   102  
   103  			// prepend build comment to original file
   104  			b, err := ioutil.ReadFile(fname)
   105  			if err != nil {
   106  				fmt.Fprintf(os.Stderr, "could not read source file: %v\n", err)
   107  				os.Exit(2)
   108  			}
   109  			b = append([]byte("// +build !instrument\n\n"), b...)
   110  			b, err = format.Source(b)
   111  			if err != nil {
   112  				fmt.Fprintf(os.Stderr, "could not format source file %v: %v\n", fname, err)
   113  				os.Exit(2)
   114  			}
   115  			f, err := os.OpenFile(fname, os.O_WRONLY, 0)
   116  			if err != nil {
   117  				fmt.Fprintf(os.Stderr, "could not open source file for writing: %v\n", err)
   118  				os.Exit(2)
   119  			}
   120  			if _, err = f.Write(b); err != nil {
   121  				fmt.Fprintf(os.Stderr, "could not write to source file: %v\n", err)
   122  				os.Exit(2)
   123  			}
   124  		}
   125  
   126  		b, err := format.Source(buf.Bytes())
   127  		if err != nil {
   128  			fmt.Fprintf(os.Stderr, "unexpected internal error: %v\n", err)
   129  			os.Exit(3)
   130  		}
   131  		os.Stdout.Write(b)
   132  	}
   133  
   134  	// fmt.Println("=======")
   135  
   136  	var buf bytes.Buffer
   137  	err = initTmpl.Execute(&buf, entries)
   138  	if err != nil {
   139  		fmt.Fprintf(os.Stderr, "unexpected internal error: %v\n", err)
   140  		os.Exit(3)
   141  	}
   142  
   143  	b, err := format.Source([]byte("package " + pkgname + string(buf.Bytes())))
   144  	if err != nil {
   145  		fmt.Fprintf(os.Stderr, "unexpected internal error: %v\n", err)
   146  		os.Exit(3)
   147  	}
   148  	_ = b
   149  	// os.Stdout.Write(b)
   150  }
   151  
   152  func parseStmt(src string) *ast.BlockStmt {
   153  	src = `package main
   154  	func a() {` + src + `}`
   155  	fset := token.NewFileSet()
   156  	a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
   157  	if err != nil {
   158  		panic(fmt.Errorf("internal error: %v", err))
   159  	}
   160  	return a.Decls[0].(*ast.FuncDecl).Body
   161  }
   162  
   163  var initTmpl *template.Template = template.Must(template.New("").Parse(`
   164  import "local/research/instrument"
   165  
   166  var (
   167  	{{range .}}{{.Flag}} bool
   168  {{end}})
   169  
   170  func init() {
   171  	{{range .}}instrument.RegisterFlag({{.Fname}}, &{{.Flag}})
   172  {{end}}}
   173  `))
   174  
   175  var shimTmpl = template.Must(template.New("").Parse(`
   176  if {{.Flag}} {
   177  	callback, ok := instrument.GetCallback({{.Fname}})
   178  	if ok {
   179  		callback({{range .Args}}{{.}},{{end}})
   180  	}
   181  }
   182  `))