github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/instrument/main_instrument.go (about)

     1  // +build instrument
     2  
     3  package main
     4  
     5  import (
     6  	"bytes"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/format"
    10  	"go/parser"
    11  	"go/printer"
    12  	"go/token"
    13  	"io/ioutil"
    14  	"local/research/instrument"
    15  	"os"
    16  	"path/filepath"
    17  	"reflect"
    18  	"strings"
    19  	"text/template"
    20  
    21  	"golang.org/x/tools/go/ast/astutil"
    22  )
    23  
    24  func main() {
    25  	if __instrument_func_main {
    26  		callback,
    27  
    28  			ok := instrument.
    29  			GetCallback(main)
    30  		if ok {
    31  			callback.(func())()
    32  		}
    33  	}
    34  
    35  	processDir(".", func(f *ast.FuncDecl) bool { return true })
    36  }
    37  
    38  type dummyArg struct {
    39  	Name, Type string
    40  }
    41  
    42  type tmplEntry struct {
    43  	Fname        string
    44  	Flag         string
    45  	CallbackType string
    46  	Args         []string
    47  	DummyArgs    []dummyArg
    48  }
    49  
    50  func processDir(path string, filter func(*ast.FuncDecl) bool) {
    51  	if __instrument_func_processDir {
    52  		callback,
    53  			ok := instrument.
    54  			GetCallback(processDir)
    55  		if ok {
    56  			callback.(func(string, func(*ast.FuncDecl) bool))(path, filter)
    57  		}
    58  	}
    59  
    60  	fs := token.NewFileSet()
    61  	parseFilter := func(fi os.FileInfo) bool {
    62  		if fi.Name() == "instrument_helper.go" ||
    63  			strings.HasSuffix(fi.Name(), "_instrument.go") ||
    64  			strings.HasSuffix(fi.Name(), "_test.go") {
    65  			return false
    66  		}
    67  		return true
    68  	}
    69  	pkgs, err := parser.ParseDir(fs, path, parseFilter, parser.ParseComments|parser.DeclarationErrors)
    70  	if err != nil {
    71  		fmt.Fprintf(os.Stderr, "could not parse package: %v\n", err)
    72  		os.Exit(2)
    73  	}
    74  
    75  	if len(pkgs) > 2 {
    76  		fmt.Fprintln(os.Stderr, "found multiple packages")
    77  		os.Exit(2)
    78  	}
    79  
    80  	if len(pkgs) == 0 {
    81  		os.Exit(0)
    82  	}
    83  
    84  	var entries []tmplEntry
    85  	var pkgname string
    86  	var pkg *ast.Package
    87  
    88  	for name, p := range pkgs {
    89  		pkgname = name
    90  		pkg = p
    91  	}
    92  
    93  	for fname, file := range pkg.Files {
    94  		_ = fname
    95  		for _, fnctmp := range file.Decls {
    96  			fnc, ok := fnctmp.(*ast.FuncDecl)
    97  			if !ok || !filter(fnc) {
    98  				continue
    99  			}
   100  			entry := funcToEntry(fs, fnc)
   101  			entries = append(entries, entry)
   102  			var buf bytes.Buffer
   103  			err := shimTmpl.Execute(&buf, entry)
   104  			if err != nil {
   105  				panic(fmt.Errorf("unexpected internal error: %v", err))
   106  			}
   107  			stmt := parseStmt(string(buf.Bytes()))
   108  			if len(stmt.List) != 1 {
   109  				panic("internal error")
   110  			}
   111  			fnc.Body.List = append([]ast.Stmt{stmt.List[0]}, fnc.Body.List...)
   112  		}
   113  
   114  		astutil.AddImport(fs, file, "local/research/instrument")
   115  
   116  		origHasBuildTag := false
   117  
   118  		for _, c := range file.Comments {
   119  			for _, c := range c.List {
   120  				if c.Text == "// +build !instrument" {
   121  					c.Text = "// +build instrument"
   122  					origHasBuildTag = true
   123  				}
   124  			}
   125  		}
   126  
   127  		var buf bytes.Buffer
   128  		if origHasBuildTag {
   129  			printer.Fprint(&buf, fs, file)
   130  		} else {
   131  			buf.Write([]byte("// +build instrument\n\n"))
   132  			printer.Fprint(&buf, fs, file)
   133  
   134  			// prepend build comment to original file
   135  			b, err := ioutil.ReadFile(fname)
   136  			if err != nil {
   137  				fmt.Fprintf(os.Stderr, "could not read source file: %v\n", err)
   138  				os.Exit(2)
   139  			}
   140  			b = append([]byte("// +build !instrument\n\n"), b...)
   141  			b, err = format.Source(b)
   142  			if err != nil {
   143  				fmt.Fprintf(os.Stderr, "could not format source file %v: %v\n", fname, err)
   144  				os.Exit(2)
   145  			}
   146  			f, err := os.OpenFile(filepath.Join(path, fname), os.O_WRONLY, 0)
   147  			if err != nil {
   148  				fmt.Fprintf(os.Stderr, "could not open source file for writing: %v\n", err)
   149  				os.Exit(2)
   150  			}
   151  			if _, err = f.Write(b); err != nil {
   152  				fmt.Fprintf(os.Stderr, "could not write to source file: %v\n", err)
   153  				os.Exit(2)
   154  			}
   155  		}
   156  
   157  		b, err := format.Source(buf.Bytes())
   158  		if err != nil {
   159  			panic(fmt.Errorf("unexpected internal error: %v", err))
   160  		}
   161  		fpath := filepath.Join(path, fname[:len(fname)-3]+"_instrument.go")
   162  		if err = ioutil.WriteFile(fpath, b, 0664); err != nil {
   163  			fmt.Fprintf(os.Stderr, "could not create instrument source file: %v\n", err)
   164  			os.Exit(2)
   165  		}
   166  	}
   167  
   168  	// create a new slice of entries, this time
   169  	// deduplicated (in case the same functions
   170  	// appear multiple times across files with
   171  	// different build constraints)
   172  	seenEntries := make(map[string]bool)
   173  	var newEntries []tmplEntry
   174  	for _, e := range entries {
   175  		if seenEntries[e.Fname] {
   176  			continue
   177  		}
   178  		seenEntries[e.Fname] = true
   179  		newEntries = append(newEntries, e)
   180  	}
   181  
   182  	var buf bytes.Buffer
   183  	err = initTmpl.Execute(&buf, newEntries)
   184  	if err != nil {
   185  		panic(fmt.Errorf("unexpected internal error: %v", err))
   186  	}
   187  
   188  	newbody := `// +build instrument
   189  	
   190  package ` + pkgname + string(buf.Bytes())
   191  
   192  	b, err := format.Source([]byte(newbody))
   193  	if err != nil {
   194  		panic(fmt.Errorf("unexpected internal error: %v", err))
   195  	}
   196  	if err = ioutil.WriteFile(filepath.Join(path, "instrument_helper.go"), b, 0664); err != nil {
   197  		fmt.Fprintf(os.Stderr, "could not create instrument_helper.go: %v\n", err)
   198  		os.Exit(2)
   199  	}
   200  }
   201  
   202  func funcToEntry(fs *token.FileSet, f *ast.FuncDecl) tmplEntry {
   203  	if __instrument_func_funcToEntry {
   204  		callback,
   205  			ok := instrument.
   206  			GetCallback(funcToEntry)
   207  		if ok {
   208  			callback.(func(*token.FileSet, *ast.
   209  				FuncDecl))(fs, f)
   210  		}
   211  	}
   212  
   213  	// NOTE: throughout this function, it's important
   214  	// that we don't modify fs or f
   215  
   216  	fname := f.Name.String()
   217  	entry := tmplEntry{Fname: fname}
   218  
   219  	cbtype := new(ast.FuncType)
   220  	cbtype.Params = new(ast.FieldList)
   221  	for _, arg := range f.Type.Params.List {
   222  		for range arg.Names {
   223  			cbtype.Params.List = append(cbtype.Params.List,
   224  				&ast.Field{Type: arg.Type})
   225  		}
   226  	}
   227  
   228  	var args []*ast.Field
   229  
   230  	if f.Recv == nil {
   231  		// it's a function
   232  		entry.Flag = "__instrument_func_" + f.Name.String()
   233  		entry.CallbackType = nodeString(fs, cbtype)
   234  	} else {
   235  		// it's a method
   236  		recv := f.Recv.List[0]
   237  
   238  		cbtype.Params.List = append([]*ast.Field{&ast.Field{Type: recv.Type}},
   239  			cbtype.Params.List...)
   240  		entry.CallbackType = nodeString(fs, cbtype)
   241  
   242  		tstr := nodeString(fs, recv.Type)
   243  		entry.Flag = "__instrument_method_"
   244  		if strings.HasPrefix(tstr, "*") {
   245  			tmp := tstr[1:]
   246  			entry.Flag += tmp + "_" + fname
   247  		} else {
   248  			entry.Flag += tstr + "_" + fname
   249  		}
   250  		entry.Fname = "(" + tstr + ")." + fname
   251  		if len(recv.Names) == 0 {
   252  			args = append(args, &ast.Field{
   253  				Type:  recv.Type,
   254  				Names: []*ast.Ident{&ast.Ident{Name: "_"}},
   255  			})
   256  		} else {
   257  			args = append(args, recv)
   258  		}
   259  	}
   260  	for _, arg := range f.Type.Params.List {
   261  		if len(arg.Names) == 0 {
   262  			args = append(args, &ast.Field{
   263  				Type:  arg.Type,
   264  				Names: []*ast.Ident{&ast.Ident{Name: "_"}},
   265  			})
   266  		} else {
   267  			for _, name := range arg.Names {
   268  				args = append(args, &ast.Field{
   269  					Type:  arg.Type,
   270  					Names: []*ast.Ident{name},
   271  				})
   272  			}
   273  		}
   274  	}
   275  
   276  	// now that we have all the args, we can go through
   277  	// and figure out which ones are anonymous (and thus
   278  	// need their own dummy args)
   279  	var dummy int
   280  	for _, arg := range args {
   281  		var name string
   282  		if arg.Names[0].Name == "_" {
   283  			name = fmt.Sprintf("dummy%v", dummy)
   284  			dummy++
   285  			entry.DummyArgs = append(entry.DummyArgs, dummyArg{
   286  				Name: name,
   287  				Type: nodeString(fs, arg.Type),
   288  			})
   289  		} else {
   290  			name = arg.Names[0].Name
   291  		}
   292  		entry.Args = append(entry.Args, name)
   293  	}
   294  
   295  	return entry
   296  }
   297  
   298  func parseStmt(src string) *ast.BlockStmt {
   299  	if __instrument_func_parseStmt {
   300  		callback,
   301  			ok := instrument.
   302  			GetCallback(
   303  				parseStmt)
   304  		if ok {
   305  			callback.(func(string))(src)
   306  		}
   307  	}
   308  
   309  	src = `package main
   310  	func a() {` + src + `}`
   311  	fset := token.NewFileSet()
   312  	a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
   313  	if err != nil {
   314  		panic(fmt.Errorf("internal error: %v", err))
   315  	}
   316  	body := a.Decls[0].(*ast.FuncDecl).Body
   317  	zeroPos(&body)
   318  	return body
   319  }
   320  
   321  // walk v and zero all values of type token.Pos
   322  func zeroPos(v interface{}) {
   323  	if __instrument_func_zeroPos {
   324  		callback,
   325  			ok :=
   326  			instrument.GetCallback(zeroPos)
   327  		if ok {
   328  			callback.(func(interface{}))(v)
   329  		}
   330  	}
   331  
   332  	rv := reflect.ValueOf(v)
   333  	if rv.Kind() != reflect.Ptr {
   334  		panic("internal error")
   335  	}
   336  	zeroPosHelper(rv)
   337  }
   338  
   339  var posTyp = reflect.TypeOf(token.Pos(0))
   340  
   341  func zeroPosHelper(rv reflect.Value) {
   342  	if __instrument_func_zeroPosHelper {
   343  		callback,
   344  			ok := instrument.
   345  			GetCallback(zeroPosHelper)
   346  		if ok {
   347  			callback.(func(reflect.Value))(rv)
   348  		}
   349  	}
   350  
   351  	if rv.Type() == posTyp {
   352  		rv.SetInt(0)
   353  		return
   354  	}
   355  	switch rv.Kind() {
   356  	case reflect.Ptr:
   357  		if !rv.IsNil() {
   358  			zeroPosHelper(rv.Elem())
   359  		}
   360  	case reflect.Slice, reflect.Array:
   361  		for i := 0; i < rv.Len(); i++ {
   362  			zeroPosHelper(rv.Index(i))
   363  		}
   364  	case reflect.Map:
   365  		keys := rv.MapKeys()
   366  		for _, k := range keys {
   367  			zeroPosHelper(rv.MapIndex(k))
   368  		}
   369  	case reflect.Struct:
   370  		for i := 0; i < rv.NumField(); i++ {
   371  			zeroPosHelper(rv.Field(i))
   372  		}
   373  	}
   374  }
   375  
   376  func nodeString(fs *token.FileSet, node interface{}) string {
   377  	if __instrument_func_nodeString {
   378  		callback,
   379  			ok := instrument.
   380  			GetCallback(nodeString)
   381  		if ok {
   382  			callback.(func(*token.
   383  				FileSet, interface{}))(fs,
   384  				node)
   385  		}
   386  	}
   387  
   388  	var buf bytes.Buffer
   389  	err := format.Node(&buf, fs, node)
   390  	if err != nil {
   391  		panic(fmt.Errorf("unexpected internal error: %v", err))
   392  	}
   393  	return string(buf.Bytes())
   394  }
   395  
   396  var initTmpl *template.Template = template.Must(template.New("").Parse(`
   397  import "local/research/instrument"
   398  
   399  var (
   400  	{{range .}}{{.Flag}} bool
   401  {{end}})
   402  
   403  func init() {
   404  	{{range .}}instrument.RegisterFlag({{.Fname}}, &{{.Flag}})
   405  {{end}}}
   406  `))
   407  
   408  var shimTmpl = template.Must(template.New("").Parse(`
   409  if {{.Flag}} {
   410  	callback, ok := instrument.GetCallback({{.Fname}})
   411  	if ok {
   412  		{{range .DummyArgs}}var {{.Name}} {{.Type}}
   413  		{{end}}
   414  		callback.({{.CallbackType}})({{range .Args}}{{.}},{{end}})
   415  	}
   416  }
   417  `))