github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/cmd/rewrite/rewrite.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/format" 8 "go/parser" 9 "go/token" 10 "go/types" 11 "reflect" 12 "text/template" 13 14 "golang.org/x/tools/go/ast/astutil" 15 ) 16 17 func rewriteGos(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) { 18 rname := runtimeName(f) 19 err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) { 20 if g, ok := s.(*ast.GoStmt); ok { 21 stmts, err := rewriteGoStmt(fset, info, qual, rname, g) 22 if stmts != nil { 23 changed = true 24 } 25 return stmts, err 26 } 27 return nil, nil 28 }) 29 if changed { 30 // astutil.AddNamedImport(fset, f, rname, "runtime") 31 astutil.AddImport(fset, f, "github.com/brownsys/tracing-framework-go/local") 32 // astutil.AddImport(fset, f, "runtime") 33 } 34 return changed, err 35 } 36 37 // runtimeName searches through f's imports to find whether 38 // the "runtime" package has been imported, and if not, whether 39 // another package whose name is also "runtime" has been 40 // imported (which would conflict if we were to add "runtime" 41 // as an import). It returns the name that should be used to 42 // identify the "runtime" package. 43 func runtimeName(f *ast.File) string { 44 for _, imp := range f.Imports { 45 if imp.Path.Value == `"runtime"` { 46 if imp.Name != nil { 47 return imp.Name.Name 48 } 49 return "runtime" 50 } 51 } 52 return "__runtime" 53 } 54 55 // nameForPackage searches through f's imports to find 56 // whether the package identified by the given path has 57 // been imported, and if not, whether another package 58 // whose name is the same has been imported (which would 59 // conflict if we were to add the given path as an 60 // import). It returns the name that should be used to 61 // identify the given package. 62 // 63 // TODO: does this actually implement the spec? 
// func nameForPackage(f *ast.File, path, name string) string {
// 	path = '"' + path + '"'
// 	for _, imp := range f.Imports {
// 		if imp.Path.Value == path {
// 			if imp.Name != nil {
// 				return imp.Name.Name
// 			}
// 			return name
// 		}
// 	}
// 	return "__" + name
// }

// mapStmts walks v and applies f to every statement of every statement list
// it contains. The lists are:
//   - the body of each *ast.BlockStmt;
//   - the body of each switch/type-switch case (*ast.CaseClause);
//   - the body of each select case (*ast.CommClause).
//
// Nested lists are processed before the lists that enclose them, so when f
// sees a statement, the statement lists inside it have already been mapped.
// A statement wrapped in an *ast.LabeledStmt is passed to f as the
// LabeledStmt itself.
//
// If f returns a nil slice, the statement is kept as is. If f returns a
// non-nil slice (of any length, including 0), its contents replace the
// statement in the list.
//
// If f ever returns a non-nil error, it is immediately returned. Lists that
// were already processed keep their replacements.
func mapStmts(v ast.Node, f func(s ast.Stmt) ([]ast.Stmt, error)) error {
	// Collect pointers to the lists so they can be replaced in place.
	var lists []*[]ast.Stmt
	ast.Inspect(v, func(n ast.Node) bool {
		switch n := n.(type) {
		case *ast.BlockStmt:
			lists = append(lists, &n.List)
		case *ast.CaseClause:
			lists = append(lists, &n.Body)
		case *ast.CommClause:
			lists = append(lists, &n.Body)
		}
		return true
	})

	// ast.Inspect visits parents before children, so walking the lists
	// backwards processes children before parents.
	for i := len(lists) - 1; i >= 0; i-- {
		list := lists[i]
		var out []ast.Stmt
		for _, s := range *list {
			repl, err := f(s)
			if err != nil {
				return err
			}
			if repl == nil {
				out = append(out, s)
			} else {
				out = append(out, repl...)
			}
		}
		*list = out
	}
	return nil
}

// rewriteGoStmt returns the statements that should replace the go statement g.
//
// The original call is wrapped in a function literal. The literal first calls
// the callback obtained from local.GetSpawnCallback(), then calls the original
// callee with the original arguments. The callee, the arguments and the
// callback are all passed as arguments of the new go statement. They are
// therefore evaluated in the spawning goroutine, as in the original statement.
//
// rname is exposed to the template as .Runtime; the current template does not
// use it.
//
// It returns (nil, nil), which leaves g untouched, when the callee is a
// builtin such as close or println. Builtins cannot be passed as function
// values, so the wrapper would not compile. Detecting builtins relies on
// info.Uses being populated. It returns an error if the callee's type is
// unknown or is not a function type.
func rewriteGoStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, g *ast.GoStmt) ([]ast.Stmt, error) {
	if isBuiltinCallee(info.Uses, g.Call.Fun) {
		return nil, nil
	}

	ftyp := info.TypeOf(g.Call.Fun)
	if ftyp == nil {
		return nil, fmt.Errorf("%v: could not determine type of function",
			fset.Position(g.Call.Fun.Pos()))
	}
	// The callee may have a named function type (type F func(int)), so
	// inspect the underlying type instead of asserting on ftyp directly.
	sig, ok := ftyp.Underlying().(*types.Signature)
	if !ok {
		return nil, fmt.Errorf("%v: called value has non-function type %v",
			fset.Position(g.Call.Fun.Pos()), ftyp)
	}

	var arg struct {
		Runtime                       string
		Func                          string
		Typ                           string
		DefArgs, InnerArgs, OuterArgs []string
	}
	arg.Runtime = rname
	arg.Func = nodeString(fset, g.Call.Fun)
	arg.Typ = types.TypeString(ftyp, qual)

	params := sig.Params()
	for i := 0; i < params.Len(); i++ {
		name := fmt.Sprintf("arg%v", i)
		if sig.Variadic() && i == params.Len()-1 {
			// The final variadic parameter has type []T. Declare it as ...T
			// and spread it when calling the callee.
			typ := types.TypeString(params.At(i).Type().(*types.Slice).Elem(), qual)
			arg.DefArgs = append(arg.DefArgs, name+" ..."+typ)
			arg.InnerArgs = append(arg.InnerArgs, name+"...")
		} else {
			typ := types.TypeString(params.At(i).Type(), qual)
			arg.DefArgs = append(arg.DefArgs, name+" "+typ)
			arg.InnerArgs = append(arg.InnerArgs, name)
		}
	}

	for i, a := range g.Call.Args {
		s := nodeString(fset, a)
		// A valid Ellipsis position means the call spreads its last
		// argument, as in f(xs...).
		if g.Call.Ellipsis.IsValid() && i == len(g.Call.Args)-1 {
			s += "..."
		}
		arg.OuterArgs = append(arg.OuterArgs, s)
	}

	var buf bytes.Buffer
	if err := goTmpl.Execute(&buf, arg); err != nil {
		panic(fmt.Errorf("internal error: %v", err))
	}
	return parseStmts(buf.String()), nil
}

// isBuiltinCallee reports whether fun, possibly parenthesized, is an
// identifier that uses records as a predeclared builtin function. It returns
// false when uses is nil or has no entry for the identifier.
func isBuiltinCallee(uses map[*ast.Ident]types.Object, fun ast.Expr) bool {
	for {
		p, ok := fun.(*ast.ParenExpr)
		if !ok {
			break
		}
		fun = p.X
	}
	id, ok := fun.(*ast.Ident)
	if !ok {
		return false
	}
	_, ok = uses[id].(*types.Builtin)
	return ok
}

// goTmpl generates the replacement for a go statement. It is executed with
// the anonymous struct built in rewriteGoStmt:
//   - .Typ is the callee's type;
//   - .DefArgs are the wrapper's parameter declarations;
//   - .InnerArgs are the arguments forwarded to the callee;
//   - .Func and .OuterArgs are the original callee and arguments.
var goTmpl = template.Must(template.New("").Parse(`
go func(__f1 func(), __f2 {{.Typ}} {{range .DefArgs}},{{.}}{{end}}){
	__f1()
	__f2({{range .InnerArgs}}{{.}},{{end}})
}(local.GetSpawnCallback(), {{.Func}}{{range .OuterArgs}},{{.}}{{end}})
`))

// parseStmts parses src as the body of a function and returns its statements
// with their positions cleared (see zeroPos), because the caller splices them
// into another file's AST. It panics if src does not parse, since src is
// generated by this program.
func parseStmts(src string) []ast.Stmt {
	src = `package main
func a() {` + src + `}`
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
	if err != nil {
		panic(fmt.Errorf("internal error: %v", err))
	}
	stmts := file.Decls[0].(*ast.FuncDecl).Body.List
	zeroPos(&stmts)
	return stmts
}

// zeroPos walks the value pointed to by v and sets every settable token.Pos
// it can reach to token.NoPos. It follows pointers, interfaces, slices,
// arrays, map values and struct fields. The one exception is
// ast.CallExpr.Ellipsis, which is kept valid; see zeroPosVisit.
// It panics if v is not a pointer.
func zeroPos(v interface{}) {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr {
		panic("internal error")
	}
	zeroPosHelper(rv)
}

var (
	posTyp      = reflect.TypeOf(token.Pos(0))
	callExprTyp = reflect.TypeOf(ast.CallExpr{})
)

// zeroPosHelper clears every token.Pos reachable from rv; see zeroPos.
func zeroPosHelper(rv reflect.Value) {
	zeroPosVisit(rv, make(map[visitKey]bool))
}

// visitKey identifies a pointer already visited by zeroPosVisit. The type is
// part of the key because a pointer to a struct and a pointer to its first
// field share an address.
type visitKey struct {
	typ reflect.Type
	ptr uintptr
}

// zeroPosVisit does the work of zeroPosHelper. It records visited pointers in
// seen so that AST cycles terminate, for example
// *ast.Ident -> *ast.Object -> declaring node -> *ast.Ident.
func zeroPosVisit(rv reflect.Value, seen map[visitKey]bool) {
	if rv.Type() == posTyp {
		// Values reached through map entries, interfaces holding non-pointers,
		// or unexported fields are not settable; leave them alone.
		if rv.CanSet() {
			rv.SetInt(0)
		}
		return
	}
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return
		}
		k := visitKey{rv.Type(), rv.Pointer()}
		if seen[k] {
			return
		}
		seen[k] = true
		zeroPosVisit(rv.Elem(), seen)
	case reflect.Interface:
		// AST children (ast.Expr, ast.Stmt, ...) are interface-typed, so this
		// case is what makes the walk reach most nodes.
		if !rv.IsNil() {
			zeroPosVisit(rv.Elem(), seen)
		}
	case reflect.Slice, reflect.Array:
		for i := 0; i < rv.Len(); i++ {
			zeroPosVisit(rv.Index(i), seen)
		}
	case reflect.Map:
		for _, k := range rv.MapKeys() {
			zeroPosVisit(rv.MapIndex(k), seen)
		}
	case reflect.Struct:
		// ast.CallExpr.Ellipsis doubles as a flag: a valid position means
		// the last argument is spread (f(xs...)). Zeroing it would silently
		// turn f(xs...) into f(xs), so it is reset to the smallest valid
		// position instead.
		var ellipsis reflect.Value
		spread := false
		if rv.Type() == callExprTyp {
			ellipsis = rv.FieldByName("Ellipsis")
			spread = token.Pos(ellipsis.Int()).IsValid()
		}
		for i := 0; i < rv.NumField(); i++ {
			zeroPosVisit(rv.Field(i), seen)
		}
		if spread && ellipsis.CanSet() {
			ellipsis.SetInt(1)
		}
	}
}

// nodeString formats node with go/format, using fset for positions, and
// returns the result. It panics if formatting fails.
func nodeString(fset *token.FileSet, node interface{}) string {
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, node); err != nil {
		panic(fmt.Errorf("unexpected internal error: %v", err))
	}
	return buf.String()
}

// qualifierForFile returns a types.Qualifier that renders each package by the
// name under which f imports it:
//   - an explicit import name if given, otherwise the package's own name;
//   - "" for packages f does not import by a usable name, so their members
//     print unqualified. This covers pkg itself and dot-imported packages.
//
// Blank ("_") imports never override a named import of the same package.
// "unsafe" imports are skipped. It panics if f imports a path that is not
// among pkg.Imports().
func qualifierForFile(pkg *types.Package, f *ast.File) types.Qualifier {
	byPath := make(map[string]*types.Package)
	for _, dep := range pkg.Imports() {
		byPath[dep.Path()] = dep
	}

	names := make(map[*types.Package]string)
	for _, imp := range f.Imports {
		if imp.Path.Value == `"unsafe"` {
			continue
		}
		// slice out quotation marks
		l := len(imp.Path.Value)
		dep, ok := byPath[imp.Path.Value[1:l-1]]
		if !ok {
			panic(fmt.Errorf("package %v (imported in %v) not in (*loader.Program).AllPackages", imp.Path.Value, f.Name.Name))
		}
		switch {
		case imp.Name == nil:
			names[dep] = dep.Name()
		case imp.Name.Name == "_", imp.Name.Name == ".":
			// "_" provides no usable qualifier, and "." puts the package's
			// names in scope unqualified. Either way, record nothing so the
			// qualifier yields "" unless a named import exists.
		default:
			names[dep] = imp.Name.Name
		}
	}
	return func(p *types.Package) string { return names[p] }
}