github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/rewrite/rewrite.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/format" 8 "go/parser" 9 "go/token" 10 "go/types" 11 "reflect" 12 "strings" 13 "text/template" 14 15 "golang.org/x/tools/go/ast/astutil" 16 ) 17 18 func rewriteGos(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) { 19 rname := runtimeName(f) 20 err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) { 21 if g, ok := s.(*ast.GoStmt); ok { 22 stmts, err := rewriteGoStmt(fset, info, qual, rname, g) 23 if stmts != nil { 24 changed = true 25 } 26 return stmts, err 27 } 28 return nil, nil 29 }) 30 if changed { 31 astutil.AddNamedImport(fset, f, rname, "runtime") 32 // astutil.AddImport(fset, f, "runtime") 33 } 34 return changed, err 35 } 36 37 func rewriteCalls(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) { 38 rname := runtimeName(f) 39 err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) { 40 if a, ok := s.(*ast.AssignStmt); ok { 41 stmts, err := rewriteCallStmt(fset, info, qual, rname, a) 42 if stmts != nil { 43 changed = true 44 } 45 return stmts, err 46 } 47 return nil, nil 48 }) 49 if changed { 50 astutil.AddNamedImport(fset, f, rname, "runtime") 51 // astutil.AddImport(fset, f, "runtime") 52 } 53 return changed, err 54 } 55 56 // runtimeName searches through f's imports to find whether 57 // the "runtime" package has been imported, and if not, whether 58 // another package whose name is also "runtime" has been 59 // imported (which would conflict if we were to add "runtime" 60 // as an import). It returns the name that should be used to 61 // identify the "runtime" package. 62 func runtimeName(f *ast.File) string { 63 for _, imp := range f.Imports { 64 if imp.Path.Value == `"runtime"` { 65 if imp.Name != nil { 66 return imp.Name.Name 67 } 68 return "runtime" 69 } 70 } 71 return "__runtime" 72 } 73 74 // mapStmts walks v, searching for values of type []ast.Stmt 75 // or [x]ast.Stmt. After recurring into such values, it loops 76 // over the slice or array, and for each element, calls f. 77 // If f returns nil, the value is left as is. If f returns a 78 // non-nil slice (of any length, including 0), the contents 79 // of this slice replace the original value in the slice or array. 80 // 81 // If f ever returns a non-nil error, it is immediately returned. 82 func mapStmts(v ast.Node, f func(s ast.Stmt) ([]ast.Stmt, error)) error { 83 var blocks []*ast.BlockStmt 84 ast.Inspect(v, func(n ast.Node) bool { 85 if b, ok := n.(*ast.BlockStmt); ok { 86 blocks = append(blocks, b) 87 } 88 return true 89 }) 90 91 // make sure to process blocks backwards 92 // so that children are processed before parents 93 for i := len(blocks) - 1; i >= 0; i-- { 94 b := blocks[i] 95 var newStmts []ast.Stmt 96 for _, s := range b.List { 97 new, err := f(s) 98 if err != nil { 99 return err 100 } 101 if new == nil { 102 newStmts = append(newStmts, s) 103 } else { 104 newStmts = append(newStmts, new...) 105 } 106 } 107 b.List = newStmts 108 } 109 110 return nil 111 } 112 113 func rewriteCallStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, a *ast.AssignStmt) ([]ast.Stmt, error) { 114 // for the time being, we only handle 115 // statements which have a single 116 // function call on the RHS, like: 117 // a, b = f() 118 119 if len(a.Rhs) != 1 { 120 for _, aa := range a.Rhs { 121 if _, ok := aa.(*ast.CallExpr); ok { 122 return nil, fmt.Errorf("%v: unsupported statement format", fset.Position(a.Pos())) 123 } 124 } 125 // none of the RHS expressions are function 126 // calls, so we can just safely ignore this 127 return nil, nil 128 } 129 130 c, ok := a.Rhs[0].(*ast.CallExpr) 131 if !ok { 132 return nil, nil 133 } 134 135 rettyp := info.TypeOf(c) 136 if rettyp == nil { 137 return nil, fmt.Errorf("%v: could not determine return type of function", 138 fset.Position(c.Pos())) 139 } 140 141 var vname string 142 143 // since the code has been type checked, 144 // we can assume that the function has 145 // at least one return value, and that 146 // len(LHS) = len(RHS) 147 if t, ok := rettyp.(*types.Tuple); ok { 148 context := false 149 for i := 0; i < t.Len(); i++ { 150 switch v := a.Lhs[i].(type) { 151 case *ast.Ident: 152 if v.Name != "_" && isContext(t.At(i).Type()) { 153 if context { 154 // more than one context.Context variable 155 return nil, fmt.Errorf("%v: unsupported statement format", fset.Position(a.Pos())) 156 } 157 context = true 158 vname = v.Name 159 } 160 default: 161 // TODO: handle LHS elements other than identifiers 162 return nil, nil 163 panic(fmt.Errorf("unexpected type %v", reflect.TypeOf(v))) 164 } 165 } 166 if !context { 167 return nil, nil 168 } 169 } else { 170 switch v := a.Lhs[0].(type) { 171 case *ast.Ident: 172 if v.Name == "_" || !isContext(rettyp) { 173 return nil, nil 174 } 175 vname = v.Name 176 default: 177 // TODO: handle LHS elements other than identifiers 178 return nil, nil 179 // panic(fmt.Errorf("unexpected type %v", reflect.TypeOf(v))) 180 } 181 } 182 183 arg := struct{ Runtime, Ctx string }{rname, vname} 184 185 var buf bytes.Buffer 186 err := callTmpl.Execute(&buf, arg) 187 if err != nil { 188 panic(fmt.Errorf("internal error: %v", err)) 189 } 190 return append([]ast.Stmt{a}, parseStmts(string(buf.Bytes()))...), nil 191 } 192 193 var callTmpl = template.Must(template.New("").Parse(`{{.Runtime}}.SetLocal({{.Ctx}})`)) 194 195 func rewriteGoStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, g *ast.GoStmt) ([]ast.Stmt, error) { 196 ftyp := info.TypeOf(g.Call.Fun) 197 198 if ftyp == nil { 199 return nil, fmt.Errorf("%v: could not determine type of function", 200 fset.Position(g.Call.Fun.Pos())) 201 } 202 sig := ftyp.(*types.Signature) 203 204 // According to the context documentation: 205 // 206 // Do not store Contexts inside a struct type; 207 // instead, pass a Context explicitly to each 208 // function that needs it. The Context should 209 // be the first parameter, typically named ctx. 210 // 211 // Thus, we only handle this case. 212 if sig.Params().Len() == 0 || !isContext(sig.Params().At(0).Type()) { 213 return nil, nil 214 } 215 216 var arg struct { 217 Runtime string 218 Func string 219 Typ string 220 DefArgs, InnerArgs, OuterArgs []string 221 } 222 223 arg.Runtime = rname 224 arg.Func = nodeString(fset, g.Call.Fun) 225 arg.Typ = types.TypeString(ftyp, qual) 226 227 params := sig.Params() 228 for i := 0; i < params.Len(); i++ { 229 typ := types.TypeString(params.At(i).Type(), qual) 230 name := fmt.Sprintf("arg%v", i) 231 if sig.Variadic() && i == params.Len()-1 { 232 arg.DefArgs = append(arg.DefArgs, name+" ..."+typ) 233 arg.InnerArgs = append(arg.InnerArgs, name+"...") 234 } else { 235 arg.DefArgs = append(arg.DefArgs, name+" "+typ) 236 arg.InnerArgs = append(arg.InnerArgs, name) 237 } 238 } 239 240 for _, a := range g.Call.Args { 241 arg.OuterArgs = append(arg.OuterArgs, nodeString(fset, a)) 242 } 243 244 var buf bytes.Buffer 245 err := goTmpl.Execute(&buf, arg) 246 if err != nil { 247 panic(fmt.Errorf("internal error: %v", err)) 248 } 249 return parseStmts(string(buf.Bytes())), nil 250 } 251 252 var goTmpl = template.Must(template.New("").Parse(` 253 go func(__f {{.Typ}} {{range .DefArgs}},{{.}}{{end}}){ 254 {{.Runtime}}.SetLocal(arg0) 255 __f({{range .InnerArgs}}{{.}},{{end}}) 256 }({{.Func}}{{range .OuterArgs}},{{.}}{{end}}) 257 `)) 258 259 func parseStmts(src string) []ast.Stmt { 260 src = `package main 261 func a() {` + src + `}` 262 fset := token.NewFileSet() 263 a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors) 264 if err != nil { 265 panic(fmt.Errorf("internal error: %v", err)) 266 } 267 stmts := a.Decls[0].(*ast.FuncDecl).Body.List 268 zeroPos(&stmts) 269 return stmts 270 } 271 272 // walk v and zero all values of type token.Pos 273 func zeroPos(v interface{}) { 274 rv := reflect.ValueOf(v) 275 if rv.Kind() != reflect.Ptr { 276 panic("internal error") 277 } 278 zeroPosHelper(rv) 279 } 280 281 var posTyp = reflect.TypeOf(token.Pos(0)) 282 283 func zeroPosHelper(rv reflect.Value) { 284 if rv.Type() == posTyp { 285 rv.SetInt(0) 286 return 287 } 288 switch rv.Kind() { 289 case reflect.Ptr: 290 if !rv.IsNil() { 291 zeroPosHelper(rv.Elem()) 292 } 293 case reflect.Slice, reflect.Array: 294 for i := 0; i < rv.Len(); i++ { 295 zeroPosHelper(rv.Index(i)) 296 } 297 case reflect.Map: 298 keys := rv.MapKeys() 299 for _, k := range keys { 300 zeroPosHelper(rv.MapIndex(k)) 301 } 302 case reflect.Struct: 303 for i := 0; i < rv.NumField(); i++ { 304 zeroPosHelper(rv.Field(i)) 305 } 306 } 307 } 308 309 func nodeString(fset *token.FileSet, node interface{}) string { 310 var buf bytes.Buffer 311 err := format.Node(&buf, fset, node) 312 if err != nil { 313 panic(fmt.Errorf("unexpected internal error: %v", err)) 314 } 315 return string(buf.Bytes()) 316 } 317 318 func qualifierForFile(pkg *types.Package, f *ast.File) types.Qualifier { 319 pathToPackage := make(map[string]*types.Package) 320 for _, pkg := range pkg.Imports() { 321 pathToPackage[pkg.Path()] = pkg 322 } 323 324 m := make(map[*types.Package]string) 325 for _, imp := range f.Imports { 326 if imp.Path.Value == `"unsafe"` { 327 continue 328 } 329 // slice out quotation marks 330 l := len(imp.Path.Value) 331 pkg, ok := pathToPackage[imp.Path.Value[1:l-1]] 332 if !ok { 333 panic(fmt.Errorf("package %v (imported in %v) not in (*loader.Program).AllPackages", imp.Path.Value, f.Name.Name)) 334 } 335 name := "" 336 if imp.Name == nil { 337 name = pkg.Name() 338 } else { 339 name = imp.Name.Name 340 } 341 m[pkg] = name 342 } 343 return func(p *types.Package) string { return m[p] } 344 } 345 346 func isContext(t types.Type) bool { 347 return t.String() == "golang.org/x/net/context.Context" || 348 t.String() == "context.Context" || strings.HasSuffix(t.String(), 349 "_workspace/src/golang.org/x/net/context.Context") 350 }