github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/instrument/main.go (about) 1 // +build !instrument 2 3 package main 4 5 import ( 6 "bytes" 7 "fmt" 8 "go/ast" 9 "go/format" 10 "go/parser" 11 "go/printer" 12 "go/token" 13 "io/ioutil" 14 "os" 15 "path/filepath" 16 "reflect" 17 "strings" 18 "text/template" 19 20 "golang.org/x/tools/go/ast/astutil" 21 ) 22 23 const ( 24 helperFile = "instrument_helper.go" 25 fileSuffix = "_instrument.go" 26 buildTag = "instrumente" 27 ) 28 29 func main() { 30 processDir(".", func(f *ast.FuncDecl) bool { return f.Name.Name != "init" }) 31 } 32 33 type dummyArg struct { 34 Name, Type string 35 } 36 37 type tmplEntry struct { 38 Fname string 39 Flag string 40 CallbackType string 41 Args []string 42 DummyArgs []dummyArg 43 } 44 45 func processDir(path string, filter func(*ast.FuncDecl) bool) { 46 fs := token.NewFileSet() 47 parseFilter := func(fi os.FileInfo) bool { 48 if fi.Name() == helperFile || 49 strings.HasSuffix(fi.Name(), fileSuffix) || 50 strings.HasSuffix(fi.Name(), "_test.go") { 51 return false 52 } 53 return true 54 } 55 pkgs, err := parser.ParseDir(fs, path, parseFilter, parser.ParseComments|parser.DeclarationErrors) 56 if err != nil { 57 fmt.Fprintf(os.Stderr, "could not parse package: %v\n", err) 58 os.Exit(2) 59 } 60 61 if len(pkgs) > 2 { 62 fmt.Fprintln(os.Stderr, "found multiple packages") 63 os.Exit(2) 64 } 65 66 if len(pkgs) == 0 { 67 os.Exit(0) 68 } 69 70 var entries []tmplEntry 71 var pkgname string 72 var pkg *ast.Package 73 74 for name, p := range pkgs { 75 pkgname = name 76 pkg = p 77 } 78 79 for fname, file := range pkg.Files { 80 _ = fname 81 for _, fnctmp := range file.Decls { 82 fnc, ok := fnctmp.(*ast.FuncDecl) 83 if !ok || !filter(fnc) { 84 continue 85 } 86 entry := funcToEntry(fs, fnc) 87 entries = append(entries, entry) 88 var buf bytes.Buffer 89 err := shimTmpl.Execute(&buf, entry) 90 if err != nil { 91 panic(fmt.Errorf("unexpected internal error: %v", err)) 92 } 93 stmt := parseStmt(string(buf.Bytes())) 94 if len(stmt.List) != 1 { 95 panic("internal error") 96 } 97 fnc.Body.List = append([]ast.Stmt{stmt.List[0]}, fnc.Body.List...) 98 } 99 100 astutil.AddImport(fs, file, "github.com/brownsys/tracing-framework-go/instrument") 101 102 origHasBuildTag := false 103 104 for _, c := range file.Comments { 105 for _, c := range c.List { 106 if c.Text == "// +build !"+buildTag { 107 c.Text = "// +build " + buildTag 108 origHasBuildTag = true 109 } 110 } 111 } 112 113 var buf bytes.Buffer 114 if origHasBuildTag { 115 printer.Fprint(&buf, fs, file) 116 } else { 117 buf.Write([]byte("// +build " + buildTag + "\n\n")) 118 printer.Fprint(&buf, fs, file) 119 120 // prepend build comment to original file 121 b, err := ioutil.ReadFile(fname) 122 if err != nil { 123 fmt.Fprintf(os.Stderr, "could not read source file: %v\n", err) 124 os.Exit(2) 125 } 126 b = append([]byte("// +build !"+buildTag+"\n\n"), b...) 127 b, err = format.Source(b) 128 if err != nil { 129 fmt.Fprintf(os.Stderr, "could not format source file %v: %v\n", fname, err) 130 os.Exit(2) 131 } 132 f, err := os.OpenFile(filepath.Join(path, fname), os.O_WRONLY, 0) 133 if err != nil { 134 fmt.Fprintf(os.Stderr, "could not open source file for writing: %v\n", err) 135 os.Exit(2) 136 } 137 if _, err = f.Write(b); err != nil { 138 fmt.Fprintf(os.Stderr, "could not write to source file: %v\n", err) 139 os.Exit(2) 140 } 141 } 142 143 b, err := format.Source(buf.Bytes()) 144 if err != nil { 145 panic(fmt.Errorf("unexpected internal error: %v", err)) 146 } 147 fpath := filepath.Join(path, fname[:len(fname)-3]+fileSuffix) 148 if err = ioutil.WriteFile(fpath, b, 0664); err != nil { 149 fmt.Fprintf(os.Stderr, "could not create instrument source file: %v\n", err) 150 os.Exit(2) 151 } 152 } 153 154 // create a new slice of entries, this time 155 // deduplicated (in case the same functions 156 // appear multiple times across files with 157 // different build constraints) 158 seenEntries := make(map[string]bool) 159 var newEntries []tmplEntry 160 for _, e := range entries { 161 if seenEntries[e.Fname] { 162 continue 163 } 164 seenEntries[e.Fname] = true 165 newEntries = append(newEntries, e) 166 } 167 168 var buf bytes.Buffer 169 err = initTmpl.Execute(&buf, newEntries) 170 if err != nil { 171 panic(fmt.Errorf("unexpected internal error: %v", err)) 172 } 173 174 newbody := `// +build ` + buildTag + ` 175 176 package ` + pkgname + string(buf.Bytes()) 177 178 b, err := format.Source([]byte(newbody)) 179 if err != nil { 180 panic(fmt.Errorf("unexpected internal error: %v", err)) 181 } 182 if err = ioutil.WriteFile(filepath.Join(path, helperFile), b, 0664); err != nil { 183 fmt.Fprintf(os.Stderr, "could not create %v: %v\n", helperFile, err) 184 os.Exit(2) 185 } 186 } 187 188 func funcToEntry(fs *token.FileSet, f *ast.FuncDecl) tmplEntry { 189 // NOTE: throughout this function, it's important 190 // that we don't modify fs or f 191 192 fname := f.Name.String() 193 entry := tmplEntry{Fname: fname} 194 195 cbtype := new(ast.FuncType) 196 cbtype.Params = new(ast.FieldList) 197 for _, arg := range f.Type.Params.List { 198 for range arg.Names { 199 cbtype.Params.List = append(cbtype.Params.List, 200 &ast.Field{Type: arg.Type}) 201 } 202 } 203 204 var args []*ast.Field 205 206 if f.Recv == nil { 207 // it's a function 208 entry.Flag = "__instrument_func_" + f.Name.String() 209 entry.CallbackType = nodeString(fs, cbtype) 210 } else { 211 // it's a method 212 recv := f.Recv.List[0] 213 214 cbtype.Params.List = append([]*ast.Field{&ast.Field{Type: recv.Type}}, 215 cbtype.Params.List...) 216 entry.CallbackType = nodeString(fs, cbtype) 217 218 tstr := nodeString(fs, recv.Type) 219 entry.Flag = "__instrument_method_" 220 if strings.HasPrefix(tstr, "*") { 221 tmp := tstr[1:] 222 entry.Flag += tmp + "_" + fname 223 } else { 224 entry.Flag += tstr + "_" + fname 225 } 226 entry.Fname = "(" + tstr + ")." + fname 227 if len(recv.Names) == 0 { 228 args = append(args, &ast.Field{ 229 Type: recv.Type, 230 Names: []*ast.Ident{&ast.Ident{Name: "_"}}, 231 }) 232 } else { 233 args = append(args, recv) 234 } 235 } 236 for _, arg := range f.Type.Params.List { 237 if len(arg.Names) == 0 { 238 args = append(args, &ast.Field{ 239 Type: arg.Type, 240 Names: []*ast.Ident{&ast.Ident{Name: "_"}}, 241 }) 242 } else { 243 for _, name := range arg.Names { 244 args = append(args, &ast.Field{ 245 Type: arg.Type, 246 Names: []*ast.Ident{name}, 247 }) 248 } 249 } 250 } 251 252 // now that we have all the args, we can go through 253 // and figure out which ones are anonymous (and thus 254 // need their own dummy args) 255 var dummy int 256 for _, arg := range args { 257 var name string 258 if arg.Names[0].Name == "_" { 259 name = fmt.Sprintf("dummy%v", dummy) 260 dummy++ 261 entry.DummyArgs = append(entry.DummyArgs, dummyArg{ 262 Name: name, 263 Type: nodeString(fs, arg.Type), 264 }) 265 } else { 266 name = arg.Names[0].Name 267 } 268 entry.Args = append(entry.Args, name) 269 } 270 271 return entry 272 } 273 274 func parseStmt(src string) *ast.BlockStmt { 275 src = `package main 276 func a() {` + src + `}` 277 fset := token.NewFileSet() 278 a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors) 279 if err != nil { 280 panic(fmt.Errorf("internal error: %v", err)) 281 } 282 body := a.Decls[0].(*ast.FuncDecl).Body 283 zeroPos(&body) 284 return body 285 } 286 287 // walk v and zero all values of type token.Pos 288 func zeroPos(v interface{}) { 289 rv := reflect.ValueOf(v) 290 if rv.Kind() != reflect.Ptr { 291 panic("internal error") 292 } 293 zeroPosHelper(rv) 294 } 295 296 var posTyp = reflect.TypeOf(token.Pos(0)) 297 298 func zeroPosHelper(rv reflect.Value) { 299 if rv.Type() == posTyp { 300 rv.SetInt(0) 301 return 302 } 303 switch rv.Kind() { 304 case reflect.Ptr: 305 if !rv.IsNil() { 306 zeroPosHelper(rv.Elem()) 307 } 308 case reflect.Slice, reflect.Array: 309 for i := 0; i < rv.Len(); i++ { 310 zeroPosHelper(rv.Index(i)) 311 } 312 case reflect.Map: 313 keys := rv.MapKeys() 314 for _, k := range keys { 315 zeroPosHelper(rv.MapIndex(k)) 316 } 317 case reflect.Struct: 318 for i := 0; i < rv.NumField(); i++ { 319 zeroPosHelper(rv.Field(i)) 320 } 321 } 322 } 323 324 func nodeString(fs *token.FileSet, node interface{}) string { 325 var buf bytes.Buffer 326 err := format.Node(&buf, fs, node) 327 if err != nil { 328 panic(fmt.Errorf("unexpected internal error: %v", err)) 329 } 330 return string(buf.Bytes()) 331 } 332 333 var initTmpl *template.Template = template.Must(template.New("").Parse(` 334 import "local/research/instrument" 335 336 var ( 337 {{range .}}{{.Flag}} bool 338 {{end}}) 339 340 func init() { 341 {{range .}}instrument.RegisterFlag({{.Fname}}, &{{.Flag}}) 342 {{end}}} 343 `)) 344 345 var shimTmpl = template.Must(template.New("").Parse(` 346 if {{.Flag}} { 347 callback, ok := instrument.GetCallback({{.Fname}}) 348 if ok { 349 {{range .DummyArgs}}var {{.Name}} {{.Type}} 350 {{end}} 351 callback.({{.CallbackType}})({{range .Args}}{{.}},{{end}}) 352 } 353 } 354 `))