github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/instrument/main_instrument.go (about) 1 // +build instrument 2 3 package main 4 5 import ( 6 "bytes" 7 "fmt" 8 "go/ast" 9 "go/format" 10 "go/parser" 11 "go/printer" 12 "go/token" 13 "io/ioutil" 14 "local/research/instrument" 15 "os" 16 "path/filepath" 17 "reflect" 18 "strings" 19 "text/template" 20 21 "golang.org/x/tools/go/ast/astutil" 22 ) 23 24 func main() { 25 if __instrument_func_main { 26 callback, 27 28 ok := instrument. 29 GetCallback(main) 30 if ok { 31 callback.(func())() 32 } 33 } 34 35 processDir(".", func(f *ast.FuncDecl) bool { return true }) 36 } 37 38 type dummyArg struct { 39 Name, Type string 40 } 41 42 type tmplEntry struct { 43 Fname string 44 Flag string 45 CallbackType string 46 Args []string 47 DummyArgs []dummyArg 48 } 49 50 func processDir(path string, filter func(*ast.FuncDecl) bool) { 51 if __instrument_func_processDir { 52 callback, 53 ok := instrument. 54 GetCallback(processDir) 55 if ok { 56 callback.(func(string, func(*ast.FuncDecl) bool))(path, filter) 57 } 58 } 59 60 fs := token.NewFileSet() 61 parseFilter := func(fi os.FileInfo) bool { 62 if fi.Name() == "instrument_helper.go" || 63 strings.HasSuffix(fi.Name(), "_instrument.go") || 64 strings.HasSuffix(fi.Name(), "_test.go") { 65 return false 66 } 67 return true 68 } 69 pkgs, err := parser.ParseDir(fs, path, parseFilter, parser.ParseComments|parser.DeclarationErrors) 70 if err != nil { 71 fmt.Fprintf(os.Stderr, "could not parse package: %v\n", err) 72 os.Exit(2) 73 } 74 75 if len(pkgs) > 2 { 76 fmt.Fprintln(os.Stderr, "found multiple packages") 77 os.Exit(2) 78 } 79 80 if len(pkgs) == 0 { 81 os.Exit(0) 82 } 83 84 var entries []tmplEntry 85 var pkgname string 86 var pkg *ast.Package 87 88 for name, p := range pkgs { 89 pkgname = name 90 pkg = p 91 } 92 93 for fname, file := range pkg.Files { 94 _ = fname 95 for _, fnctmp := range file.Decls { 96 fnc, ok := fnctmp.(*ast.FuncDecl) 97 if !ok || !filter(fnc) { 98 continue 99 } 100 entry := funcToEntry(fs, fnc) 101 entries = append(entries, entry) 102 var buf bytes.Buffer 103 err := shimTmpl.Execute(&buf, entry) 104 if err != nil { 105 panic(fmt.Errorf("unexpected internal error: %v", err)) 106 } 107 stmt := parseStmt(string(buf.Bytes())) 108 if len(stmt.List) != 1 { 109 panic("internal error") 110 } 111 fnc.Body.List = append([]ast.Stmt{stmt.List[0]}, fnc.Body.List...) 112 } 113 114 astutil.AddImport(fs, file, "local/research/instrument") 115 116 origHasBuildTag := false 117 118 for _, c := range file.Comments { 119 for _, c := range c.List { 120 if c.Text == "// +build !instrument" { 121 c.Text = "// +build instrument" 122 origHasBuildTag = true 123 } 124 } 125 } 126 127 var buf bytes.Buffer 128 if origHasBuildTag { 129 printer.Fprint(&buf, fs, file) 130 } else { 131 buf.Write([]byte("// +build instrument\n\n")) 132 printer.Fprint(&buf, fs, file) 133 134 // prepend build comment to original file 135 b, err := ioutil.ReadFile(fname) 136 if err != nil { 137 fmt.Fprintf(os.Stderr, "could not read source file: %v\n", err) 138 os.Exit(2) 139 } 140 b = append([]byte("// +build !instrument\n\n"), b...) 141 b, err = format.Source(b) 142 if err != nil { 143 fmt.Fprintf(os.Stderr, "could not format source file %v: %v\n", fname, err) 144 os.Exit(2) 145 } 146 f, err := os.OpenFile(filepath.Join(path, fname), os.O_WRONLY, 0) 147 if err != nil { 148 fmt.Fprintf(os.Stderr, "could not open source file for writing: %v\n", err) 149 os.Exit(2) 150 } 151 if _, err = f.Write(b); err != nil { 152 fmt.Fprintf(os.Stderr, "could not write to source file: %v\n", err) 153 os.Exit(2) 154 } 155 } 156 157 b, err := format.Source(buf.Bytes()) 158 if err != nil { 159 panic(fmt.Errorf("unexpected internal error: %v", err)) 160 } 161 fpath := filepath.Join(path, fname[:len(fname)-3]+"_instrument.go") 162 if err = ioutil.WriteFile(fpath, b, 0664); err != nil { 163 fmt.Fprintf(os.Stderr, "could not create instrument source file: %v\n", err) 164 os.Exit(2) 165 } 166 } 167 168 // create a new slice of entries, this time 169 // deduplicated (in case the same functions 170 // appear multiple times across files with 171 // different build constraints) 172 seenEntries := make(map[string]bool) 173 var newEntries []tmplEntry 174 for _, e := range entries { 175 if seenEntries[e.Fname] { 176 continue 177 } 178 seenEntries[e.Fname] = true 179 newEntries = append(newEntries, e) 180 } 181 182 var buf bytes.Buffer 183 err = initTmpl.Execute(&buf, newEntries) 184 if err != nil { 185 panic(fmt.Errorf("unexpected internal error: %v", err)) 186 } 187 188 newbody := `// +build instrument 189 190 package ` + pkgname + string(buf.Bytes()) 191 192 b, err := format.Source([]byte(newbody)) 193 if err != nil { 194 panic(fmt.Errorf("unexpected internal error: %v", err)) 195 } 196 if err = ioutil.WriteFile(filepath.Join(path, "instrument_helper.go"), b, 0664); err != nil { 197 fmt.Fprintf(os.Stderr, "could not create instrument_helper.go: %v\n", err) 198 os.Exit(2) 199 } 200 } 201 202 func funcToEntry(fs *token.FileSet, f *ast.FuncDecl) tmplEntry { 203 if __instrument_func_funcToEntry { 204 callback, 205 ok := instrument. 206 GetCallback(funcToEntry) 207 if ok { 208 callback.(func(*token.FileSet, *ast. 209 FuncDecl))(fs, f) 210 } 211 } 212 213 // NOTE: throughout this function, it's important 214 // that we don't modify fs or f 215 216 fname := f.Name.String() 217 entry := tmplEntry{Fname: fname} 218 219 cbtype := new(ast.FuncType) 220 cbtype.Params = new(ast.FieldList) 221 for _, arg := range f.Type.Params.List { 222 for range arg.Names { 223 cbtype.Params.List = append(cbtype.Params.List, 224 &ast.Field{Type: arg.Type}) 225 } 226 } 227 228 var args []*ast.Field 229 230 if f.Recv == nil { 231 // it's a function 232 entry.Flag = "__instrument_func_" + f.Name.String() 233 entry.CallbackType = nodeString(fs, cbtype) 234 } else { 235 // it's a method 236 recv := f.Recv.List[0] 237 238 cbtype.Params.List = append([]*ast.Field{&ast.Field{Type: recv.Type}}, 239 cbtype.Params.List...) 240 entry.CallbackType = nodeString(fs, cbtype) 241 242 tstr := nodeString(fs, recv.Type) 243 entry.Flag = "__instrument_method_" 244 if strings.HasPrefix(tstr, "*") { 245 tmp := tstr[1:] 246 entry.Flag += tmp + "_" + fname 247 } else { 248 entry.Flag += tstr + "_" + fname 249 } 250 entry.Fname = "(" + tstr + ")." + fname 251 if len(recv.Names) == 0 { 252 args = append(args, &ast.Field{ 253 Type: recv.Type, 254 Names: []*ast.Ident{&ast.Ident{Name: "_"}}, 255 }) 256 } else { 257 args = append(args, recv) 258 } 259 } 260 for _, arg := range f.Type.Params.List { 261 if len(arg.Names) == 0 { 262 args = append(args, &ast.Field{ 263 Type: arg.Type, 264 Names: []*ast.Ident{&ast.Ident{Name: "_"}}, 265 }) 266 } else { 267 for _, name := range arg.Names { 268 args = append(args, &ast.Field{ 269 Type: arg.Type, 270 Names: []*ast.Ident{name}, 271 }) 272 } 273 } 274 } 275 276 // now that we have all the args, we can go through 277 // and figure out which ones are anonymous (and thus 278 // need their own dummy args) 279 var dummy int 280 for _, arg := range args { 281 var name string 282 if arg.Names[0].Name == "_" { 283 name = fmt.Sprintf("dummy%v", dummy) 284 dummy++ 285 entry.DummyArgs = append(entry.DummyArgs, dummyArg{ 286 Name: name, 287 Type: nodeString(fs, arg.Type), 288 }) 289 } else { 290 name = arg.Names[0].Name 291 } 292 entry.Args = append(entry.Args, name) 293 } 294 295 return entry 296 } 297 298 func parseStmt(src string) *ast.BlockStmt { 299 if __instrument_func_parseStmt { 300 callback, 301 ok := instrument. 302 GetCallback( 303 parseStmt) 304 if ok { 305 callback.(func(string))(src) 306 } 307 } 308 309 src = `package main 310 func a() {` + src + `}` 311 fset := token.NewFileSet() 312 a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors) 313 if err != nil { 314 panic(fmt.Errorf("internal error: %v", err)) 315 } 316 body := a.Decls[0].(*ast.FuncDecl).Body 317 zeroPos(&body) 318 return body 319 } 320 321 // walk v and zero all values of type token.Pos 322 func zeroPos(v interface{}) { 323 if __instrument_func_zeroPos { 324 callback, 325 ok := 326 instrument.GetCallback(zeroPos) 327 if ok { 328 callback.(func(interface{}))(v) 329 } 330 } 331 332 rv := reflect.ValueOf(v) 333 if rv.Kind() != reflect.Ptr { 334 panic("internal error") 335 } 336 zeroPosHelper(rv) 337 } 338 339 var posTyp = reflect.TypeOf(token.Pos(0)) 340 341 func zeroPosHelper(rv reflect.Value) { 342 if __instrument_func_zeroPosHelper { 343 callback, 344 ok := instrument. 345 GetCallback(zeroPosHelper) 346 if ok { 347 callback.(func(reflect.Value))(rv) 348 } 349 } 350 351 if rv.Type() == posTyp { 352 rv.SetInt(0) 353 return 354 } 355 switch rv.Kind() { 356 case reflect.Ptr: 357 if !rv.IsNil() { 358 zeroPosHelper(rv.Elem()) 359 } 360 case reflect.Slice, reflect.Array: 361 for i := 0; i < rv.Len(); i++ { 362 zeroPosHelper(rv.Index(i)) 363 } 364 case reflect.Map: 365 keys := rv.MapKeys() 366 for _, k := range keys { 367 zeroPosHelper(rv.MapIndex(k)) 368 } 369 case reflect.Struct: 370 for i := 0; i < rv.NumField(); i++ { 371 zeroPosHelper(rv.Field(i)) 372 } 373 } 374 } 375 376 func nodeString(fs *token.FileSet, node interface{}) string { 377 if __instrument_func_nodeString { 378 callback, 379 ok := instrument. 380 GetCallback(nodeString) 381 if ok { 382 callback.(func(*token. 383 FileSet, interface{}))(fs, 384 node) 385 } 386 } 387 388 var buf bytes.Buffer 389 err := format.Node(&buf, fs, node) 390 if err != nil { 391 panic(fmt.Errorf("unexpected internal error: %v", err)) 392 } 393 return string(buf.Bytes()) 394 } 395 396 var initTmpl *template.Template = template.Must(template.New("").Parse(` 397 import "local/research/instrument" 398 399 var ( 400 {{range .}}{{.Flag}} bool 401 {{end}}) 402 403 func init() { 404 {{range .}}instrument.RegisterFlag({{.Fname}}, &{{.Flag}}) 405 {{end}}} 406 `)) 407 408 var shimTmpl = template.Must(template.New("").Parse(` 409 if {{.Flag}} { 410 callback, ok := instrument.GetCallback({{.Fname}}) 411 if ok { 412 {{range .DummyArgs}}var {{.Name}} {{.Type}} 413 {{end}} 414 callback.({{.CallbackType}})({{range .Args}}{{.}},{{end}}) 415 } 416 } 417 `))