github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/instrument/main.go (about)

     1  // +build !instrument
     2  
     3  package main
     4  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/printer"
	"go/token"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"text/template"

	"golang.org/x/tools/go/ast/astutil"
)
    22  
    23  const (
    24  	helperFile = "instrument_helper.go"
    25  	fileSuffix = "_instrument.go"
    26  	buildTag   = "instrumente"
    27  )
    28  
    29  func main() {
    30  	processDir(".", func(f *ast.FuncDecl) bool { return f.Name.Name != "init" })
    31  }
    32  
    33  type dummyArg struct {
    34  	Name, Type string
    35  }
    36  
    37  type tmplEntry struct {
    38  	Fname        string
    39  	Flag         string
    40  	CallbackType string
    41  	Args         []string
    42  	DummyArgs    []dummyArg
    43  }
    44  
    45  func processDir(path string, filter func(*ast.FuncDecl) bool) {
    46  	fs := token.NewFileSet()
    47  	parseFilter := func(fi os.FileInfo) bool {
    48  		if fi.Name() == helperFile ||
    49  			strings.HasSuffix(fi.Name(), fileSuffix) ||
    50  			strings.HasSuffix(fi.Name(), "_test.go") {
    51  			return false
    52  		}
    53  		return true
    54  	}
    55  	pkgs, err := parser.ParseDir(fs, path, parseFilter, parser.ParseComments|parser.DeclarationErrors)
    56  	if err != nil {
    57  		fmt.Fprintf(os.Stderr, "could not parse package: %v\n", err)
    58  		os.Exit(2)
    59  	}
    60  
    61  	if len(pkgs) > 2 {
    62  		fmt.Fprintln(os.Stderr, "found multiple packages")
    63  		os.Exit(2)
    64  	}
    65  
    66  	if len(pkgs) == 0 {
    67  		os.Exit(0)
    68  	}
    69  
    70  	var entries []tmplEntry
    71  	var pkgname string
    72  	var pkg *ast.Package
    73  
    74  	for name, p := range pkgs {
    75  		pkgname = name
    76  		pkg = p
    77  	}
    78  
    79  	for fname, file := range pkg.Files {
    80  		_ = fname
    81  		for _, fnctmp := range file.Decls {
    82  			fnc, ok := fnctmp.(*ast.FuncDecl)
    83  			if !ok || !filter(fnc) {
    84  				continue
    85  			}
    86  			entry := funcToEntry(fs, fnc)
    87  			entries = append(entries, entry)
    88  			var buf bytes.Buffer
    89  			err := shimTmpl.Execute(&buf, entry)
    90  			if err != nil {
    91  				panic(fmt.Errorf("unexpected internal error: %v", err))
    92  			}
    93  			stmt := parseStmt(string(buf.Bytes()))
    94  			if len(stmt.List) != 1 {
    95  				panic("internal error")
    96  			}
    97  			fnc.Body.List = append([]ast.Stmt{stmt.List[0]}, fnc.Body.List...)
    98  		}
    99  
   100  		astutil.AddImport(fs, file, "github.com/brownsys/tracing-framework-go/instrument")
   101  
   102  		origHasBuildTag := false
   103  
   104  		for _, c := range file.Comments {
   105  			for _, c := range c.List {
   106  				if c.Text == "// +build !"+buildTag {
   107  					c.Text = "// +build " + buildTag
   108  					origHasBuildTag = true
   109  				}
   110  			}
   111  		}
   112  
   113  		var buf bytes.Buffer
   114  		if origHasBuildTag {
   115  			printer.Fprint(&buf, fs, file)
   116  		} else {
   117  			buf.Write([]byte("// +build " + buildTag + "\n\n"))
   118  			printer.Fprint(&buf, fs, file)
   119  
   120  			// prepend build comment to original file
   121  			b, err := ioutil.ReadFile(fname)
   122  			if err != nil {
   123  				fmt.Fprintf(os.Stderr, "could not read source file: %v\n", err)
   124  				os.Exit(2)
   125  			}
   126  			b = append([]byte("// +build !"+buildTag+"\n\n"), b...)
   127  			b, err = format.Source(b)
   128  			if err != nil {
   129  				fmt.Fprintf(os.Stderr, "could not format source file %v: %v\n", fname, err)
   130  				os.Exit(2)
   131  			}
   132  			f, err := os.OpenFile(filepath.Join(path, fname), os.O_WRONLY, 0)
   133  			if err != nil {
   134  				fmt.Fprintf(os.Stderr, "could not open source file for writing: %v\n", err)
   135  				os.Exit(2)
   136  			}
   137  			if _, err = f.Write(b); err != nil {
   138  				fmt.Fprintf(os.Stderr, "could not write to source file: %v\n", err)
   139  				os.Exit(2)
   140  			}
   141  		}
   142  
   143  		b, err := format.Source(buf.Bytes())
   144  		if err != nil {
   145  			panic(fmt.Errorf("unexpected internal error: %v", err))
   146  		}
   147  		fpath := filepath.Join(path, fname[:len(fname)-3]+fileSuffix)
   148  		if err = ioutil.WriteFile(fpath, b, 0664); err != nil {
   149  			fmt.Fprintf(os.Stderr, "could not create instrument source file: %v\n", err)
   150  			os.Exit(2)
   151  		}
   152  	}
   153  
   154  	// create a new slice of entries, this time
   155  	// deduplicated (in case the same functions
   156  	// appear multiple times across files with
   157  	// different build constraints)
   158  	seenEntries := make(map[string]bool)
   159  	var newEntries []tmplEntry
   160  	for _, e := range entries {
   161  		if seenEntries[e.Fname] {
   162  			continue
   163  		}
   164  		seenEntries[e.Fname] = true
   165  		newEntries = append(newEntries, e)
   166  	}
   167  
   168  	var buf bytes.Buffer
   169  	err = initTmpl.Execute(&buf, newEntries)
   170  	if err != nil {
   171  		panic(fmt.Errorf("unexpected internal error: %v", err))
   172  	}
   173  
   174  	newbody := `// +build ` + buildTag + `
   175  	
   176  package ` + pkgname + string(buf.Bytes())
   177  
   178  	b, err := format.Source([]byte(newbody))
   179  	if err != nil {
   180  		panic(fmt.Errorf("unexpected internal error: %v", err))
   181  	}
   182  	if err = ioutil.WriteFile(filepath.Join(path, helperFile), b, 0664); err != nil {
   183  		fmt.Fprintf(os.Stderr, "could not create %v: %v\n", helperFile, err)
   184  		os.Exit(2)
   185  	}
   186  }
   187  
   188  func funcToEntry(fs *token.FileSet, f *ast.FuncDecl) tmplEntry {
   189  	// NOTE: throughout this function, it's important
   190  	// that we don't modify fs or f
   191  
   192  	fname := f.Name.String()
   193  	entry := tmplEntry{Fname: fname}
   194  
   195  	cbtype := new(ast.FuncType)
   196  	cbtype.Params = new(ast.FieldList)
   197  	for _, arg := range f.Type.Params.List {
   198  		for range arg.Names {
   199  			cbtype.Params.List = append(cbtype.Params.List,
   200  				&ast.Field{Type: arg.Type})
   201  		}
   202  	}
   203  
   204  	var args []*ast.Field
   205  
   206  	if f.Recv == nil {
   207  		// it's a function
   208  		entry.Flag = "__instrument_func_" + f.Name.String()
   209  		entry.CallbackType = nodeString(fs, cbtype)
   210  	} else {
   211  		// it's a method
   212  		recv := f.Recv.List[0]
   213  
   214  		cbtype.Params.List = append([]*ast.Field{&ast.Field{Type: recv.Type}},
   215  			cbtype.Params.List...)
   216  		entry.CallbackType = nodeString(fs, cbtype)
   217  
   218  		tstr := nodeString(fs, recv.Type)
   219  		entry.Flag = "__instrument_method_"
   220  		if strings.HasPrefix(tstr, "*") {
   221  			tmp := tstr[1:]
   222  			entry.Flag += tmp + "_" + fname
   223  		} else {
   224  			entry.Flag += tstr + "_" + fname
   225  		}
   226  		entry.Fname = "(" + tstr + ")." + fname
   227  		if len(recv.Names) == 0 {
   228  			args = append(args, &ast.Field{
   229  				Type:  recv.Type,
   230  				Names: []*ast.Ident{&ast.Ident{Name: "_"}},
   231  			})
   232  		} else {
   233  			args = append(args, recv)
   234  		}
   235  	}
   236  	for _, arg := range f.Type.Params.List {
   237  		if len(arg.Names) == 0 {
   238  			args = append(args, &ast.Field{
   239  				Type:  arg.Type,
   240  				Names: []*ast.Ident{&ast.Ident{Name: "_"}},
   241  			})
   242  		} else {
   243  			for _, name := range arg.Names {
   244  				args = append(args, &ast.Field{
   245  					Type:  arg.Type,
   246  					Names: []*ast.Ident{name},
   247  				})
   248  			}
   249  		}
   250  	}
   251  
   252  	// now that we have all the args, we can go through
   253  	// and figure out which ones are anonymous (and thus
   254  	// need their own dummy args)
   255  	var dummy int
   256  	for _, arg := range args {
   257  		var name string
   258  		if arg.Names[0].Name == "_" {
   259  			name = fmt.Sprintf("dummy%v", dummy)
   260  			dummy++
   261  			entry.DummyArgs = append(entry.DummyArgs, dummyArg{
   262  				Name: name,
   263  				Type: nodeString(fs, arg.Type),
   264  			})
   265  		} else {
   266  			name = arg.Names[0].Name
   267  		}
   268  		entry.Args = append(entry.Args, name)
   269  	}
   270  
   271  	return entry
   272  }
   273  
   274  func parseStmt(src string) *ast.BlockStmt {
   275  	src = `package main
   276  	func a() {` + src + `}`
   277  	fset := token.NewFileSet()
   278  	a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
   279  	if err != nil {
   280  		panic(fmt.Errorf("internal error: %v", err))
   281  	}
   282  	body := a.Decls[0].(*ast.FuncDecl).Body
   283  	zeroPos(&body)
   284  	return body
   285  }
   286  
   287  // walk v and zero all values of type token.Pos
   288  func zeroPos(v interface{}) {
   289  	rv := reflect.ValueOf(v)
   290  	if rv.Kind() != reflect.Ptr {
   291  		panic("internal error")
   292  	}
   293  	zeroPosHelper(rv)
   294  }
   295  
   296  var posTyp = reflect.TypeOf(token.Pos(0))
   297  
   298  func zeroPosHelper(rv reflect.Value) {
   299  	if rv.Type() == posTyp {
   300  		rv.SetInt(0)
   301  		return
   302  	}
   303  	switch rv.Kind() {
   304  	case reflect.Ptr:
   305  		if !rv.IsNil() {
   306  			zeroPosHelper(rv.Elem())
   307  		}
   308  	case reflect.Slice, reflect.Array:
   309  		for i := 0; i < rv.Len(); i++ {
   310  			zeroPosHelper(rv.Index(i))
   311  		}
   312  	case reflect.Map:
   313  		keys := rv.MapKeys()
   314  		for _, k := range keys {
   315  			zeroPosHelper(rv.MapIndex(k))
   316  		}
   317  	case reflect.Struct:
   318  		for i := 0; i < rv.NumField(); i++ {
   319  			zeroPosHelper(rv.Field(i))
   320  		}
   321  	}
   322  }
   323  
   324  func nodeString(fs *token.FileSet, node interface{}) string {
   325  	var buf bytes.Buffer
   326  	err := format.Node(&buf, fs, node)
   327  	if err != nil {
   328  		panic(fmt.Errorf("unexpected internal error: %v", err))
   329  	}
   330  	return string(buf.Bytes())
   331  }
   332  
   333  var initTmpl *template.Template = template.Must(template.New("").Parse(`
   334  import "local/research/instrument"
   335  
   336  var (
   337  	{{range .}}{{.Flag}} bool
   338  {{end}})
   339  
   340  func init() {
   341  	{{range .}}instrument.RegisterFlag({{.Fname}}, &{{.Flag}})
   342  {{end}}}
   343  `))
   344  
   345  var shimTmpl = template.Must(template.New("").Parse(`
   346  if {{.Flag}} {
   347  	callback, ok := instrument.GetCallback({{.Fname}})
   348  	if ok {
   349  		{{range .DummyArgs}}var {{.Name}} {{.Type}}
   350  		{{end}}
   351  		callback.({{.CallbackType}})({{range .Args}}{{.}},{{end}})
   352  	}
   353  }
   354  `))