github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/refactor/eg/rewrite.go (about) 1 // Copyright 2014 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package eg 6 7 // This file defines the AST rewriting pass. 8 // Most of it was plundered directly from 9 // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution). 10 11 import ( 12 "fmt" 13 "go/ast" 14 "go/token" 15 "go/types" 16 "os" 17 "reflect" 18 "sort" 19 "strconv" 20 "strings" 21 22 "github.com/jhump/golang-x-tools/go/ast/astutil" 23 ) 24 25 // transformItem takes a reflect.Value representing a variable of type ast.Node 26 // transforms its child elements recursively with apply, and then transforms the 27 // actual element if it contains an expression. 28 func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { 29 // don't bother if val is invalid to start with 30 if !rv.IsValid() { 31 return reflect.Value{}, false, nil 32 } 33 34 rv, changed, newEnv := tr.apply(tr.transformItem, rv) 35 36 e := rvToExpr(rv) 37 if e == nil { 38 return rv, changed, newEnv 39 } 40 41 savedEnv := tr.env 42 tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs 43 44 if tr.matchExpr(tr.before, e) { 45 if tr.verbose { 46 fmt.Fprintf(os.Stderr, "%s matches %s", 47 astString(tr.fset, tr.before), astString(tr.fset, e)) 48 if len(tr.env) > 0 { 49 fmt.Fprintf(os.Stderr, " with:") 50 for name, ast := range tr.env { 51 fmt.Fprintf(os.Stderr, " %s->%s", 52 name, astString(tr.fset, ast)) 53 } 54 } 55 fmt.Fprintf(os.Stderr, "\n") 56 } 57 tr.nsubsts++ 58 59 // Clone the replacement tree, performing parameter substitution. 60 // We update all positions to n.Pos() to aid comment placement. 61 rv = tr.subst(tr.env, reflect.ValueOf(tr.after), 62 reflect.ValueOf(e.Pos())) 63 changed = true 64 newEnv = tr.env 65 } 66 tr.env = savedEnv 67 68 return rv, changed, newEnv 69 } 70 71 // Transform applies the transformation to the specified parsed file, 72 // whose type information is supplied in info, and returns the number 73 // of replacements that were made. 74 // 75 // It mutates the AST in place (the identity of the root node is 76 // unchanged), and may add nodes for which no type information is 77 // available in info. 78 // 79 // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. 80 // 81 func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { 82 if !tr.seenInfos[info] { 83 tr.seenInfos[info] = true 84 mergeTypeInfo(tr.info, info) 85 } 86 tr.currentPkg = pkg 87 tr.nsubsts = 0 88 89 if tr.verbose { 90 fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) 91 fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) 92 fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts) 93 } 94 95 o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file)) 96 if changed { 97 panic("BUG") 98 } 99 file2 := o.Interface().(*ast.File) 100 101 // By construction, the root node is unchanged. 102 if file != file2 { 103 panic("BUG") 104 } 105 106 // Add any necessary imports. 107 // TODO(adonovan): remove no-longer needed imports too. 108 if tr.nsubsts > 0 { 109 pkgs := make(map[string]*types.Package) 110 for obj := range tr.importedObjs { 111 pkgs[obj.Pkg().Path()] = obj.Pkg() 112 } 113 114 for _, imp := range file.Imports { 115 path, _ := strconv.Unquote(imp.Path.Value) 116 delete(pkgs, path) 117 } 118 delete(pkgs, pkg.Path()) // don't import self 119 120 // NB: AddImport may completely replace the AST! 121 // It thus renders info and tr.info no longer relevant to file. 122 var paths []string 123 for path := range pkgs { 124 paths = append(paths, path) 125 } 126 sort.Strings(paths) 127 for _, path := range paths { 128 astutil.AddImport(tr.fset, file, path) 129 } 130 } 131 132 tr.currentPkg = nil 133 134 return tr.nsubsts 135 } 136 137 // setValue is a wrapper for x.SetValue(y); it protects 138 // the caller from panics if x cannot be changed to y. 139 func setValue(x, y reflect.Value) { 140 // don't bother if y is invalid to start with 141 if !y.IsValid() { 142 return 143 } 144 defer func() { 145 if x := recover(); x != nil { 146 if s, ok := x.(string); ok && 147 (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) { 148 // x cannot be set to y - ignore this rewrite 149 return 150 } 151 panic(x) 152 } 153 }() 154 x.Set(y) 155 } 156 157 // Values/types for special cases. 158 var ( 159 objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) 160 scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) 161 162 identType = reflect.TypeOf((*ast.Ident)(nil)) 163 selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil)) 164 objectPtrType = reflect.TypeOf((*ast.Object)(nil)) 165 statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem() 166 positionType = reflect.TypeOf(token.NoPos) 167 scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) 168 ) 169 170 // apply replaces each AST field x in val with f(x), returning val. 171 // To avoid extra conversions, f operates on the reflect.Value form. 172 // f takes a reflect.Value representing the variable to modify of type ast.Node. 173 // It returns a reflect.Value containing the transformed value of type ast.Node, 174 // whether any change was made, and a map of identifiers to ast.Expr (so we can 175 // do contextually correct substitutions in the parent statements). 176 func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { 177 if !val.IsValid() { 178 return reflect.Value{}, false, nil 179 } 180 181 // *ast.Objects introduce cycles and are likely incorrect after 182 // rewrite; don't follow them but replace with nil instead 183 if val.Type() == objectPtrType { 184 return objectPtrNil, false, nil 185 } 186 187 // similarly for scopes: they are likely incorrect after a rewrite; 188 // replace them with nil 189 if val.Type() == scopePtrType { 190 return scopePtrNil, false, nil 191 } 192 193 switch v := reflect.Indirect(val); v.Kind() { 194 case reflect.Slice: 195 // no possible rewriting of statements. 196 if v.Type().Elem() != statementType { 197 changed := false 198 var envp map[string]ast.Expr 199 for i := 0; i < v.Len(); i++ { 200 e := v.Index(i) 201 o, localchanged, env := f(e) 202 if localchanged { 203 changed = true 204 // we clobber envp here, 205 // which means if we have two successive 206 // replacements inside the same statement 207 // we will only generate the setup for one of them. 208 envp = env 209 } 210 setValue(e, o) 211 } 212 return val, changed, envp 213 } 214 215 // statements are rewritten. 216 var out []ast.Stmt 217 for i := 0; i < v.Len(); i++ { 218 e := v.Index(i) 219 o, changed, env := f(e) 220 if changed { 221 for _, s := range tr.afterStmts { 222 t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface() 223 out = append(out, t.(ast.Stmt)) 224 } 225 } 226 setValue(e, o) 227 out = append(out, e.Interface().(ast.Stmt)) 228 } 229 return reflect.ValueOf(out), false, nil 230 case reflect.Struct: 231 changed := false 232 var envp map[string]ast.Expr 233 for i := 0; i < v.NumField(); i++ { 234 e := v.Field(i) 235 o, localchanged, env := f(e) 236 if localchanged { 237 changed = true 238 envp = env 239 } 240 setValue(e, o) 241 } 242 return val, changed, envp 243 case reflect.Interface: 244 e := v.Elem() 245 o, changed, env := f(e) 246 setValue(v, o) 247 return val, changed, env 248 } 249 return val, false, nil 250 } 251 252 // subst returns a copy of (replacement) pattern with values from env 253 // substituted in place of wildcards and pos used as the position of 254 // tokens from the pattern. if env == nil, subst returns a copy of 255 // pattern and doesn't change the line number information. 256 func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value { 257 if !pattern.IsValid() { 258 return reflect.Value{} 259 } 260 261 // *ast.Objects introduce cycles and are likely incorrect after 262 // rewrite; don't follow them but replace with nil instead 263 if pattern.Type() == objectPtrType { 264 return objectPtrNil 265 } 266 267 // similarly for scopes: they are likely incorrect after a rewrite; 268 // replace them with nil 269 if pattern.Type() == scopePtrType { 270 return scopePtrNil 271 } 272 273 // Wildcard gets replaced with map value. 274 if env != nil && pattern.Type() == identType { 275 id := pattern.Interface().(*ast.Ident) 276 if old, ok := env[id.Name]; ok { 277 return tr.subst(nil, reflect.ValueOf(old), reflect.Value{}) 278 } 279 } 280 281 // Emit qualified identifiers in the pattern by appropriate 282 // (possibly qualified) identifier in the input. 283 // 284 // The template cannot contain dot imports, so all identifiers 285 // for imported objects are explicitly qualified. 286 // 287 // We assume (unsoundly) that there are no dot or named 288 // imports in the input code, nor are any imported package 289 // names shadowed, so the usual normal qualified identifier 290 // syntax may be used. 291 // TODO(adonovan): fix: avoid this assumption. 292 // 293 // A refactoring may be applied to a package referenced by the 294 // template. Objects belonging to the current package are 295 // denoted by unqualified identifiers. 296 // 297 if tr.importedObjs != nil && pattern.Type() == selectorExprType { 298 obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info) 299 if obj != nil { 300 if sel, ok := tr.importedObjs[obj]; ok { 301 var id ast.Expr 302 if obj.Pkg() == tr.currentPkg { 303 id = sel.Sel // unqualified 304 } else { 305 id = sel // pkg-qualified 306 } 307 308 // Return a clone of id. 309 saved := tr.importedObjs 310 tr.importedObjs = nil // break cycle 311 r := tr.subst(nil, reflect.ValueOf(id), pos) 312 tr.importedObjs = saved 313 return r 314 } 315 } 316 } 317 318 if pos.IsValid() && pattern.Type() == positionType { 319 // use new position only if old position was valid in the first place 320 if old := pattern.Interface().(token.Pos); !old.IsValid() { 321 return pattern 322 } 323 return pos 324 } 325 326 // Otherwise copy. 327 switch p := pattern; p.Kind() { 328 case reflect.Slice: 329 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) 330 for i := 0; i < p.Len(); i++ { 331 v.Index(i).Set(tr.subst(env, p.Index(i), pos)) 332 } 333 return v 334 335 case reflect.Struct: 336 v := reflect.New(p.Type()).Elem() 337 for i := 0; i < p.NumField(); i++ { 338 v.Field(i).Set(tr.subst(env, p.Field(i), pos)) 339 } 340 return v 341 342 case reflect.Ptr: 343 v := reflect.New(p.Type()).Elem() 344 if elem := p.Elem(); elem.IsValid() { 345 v.Set(tr.subst(env, elem, pos).Addr()) 346 } 347 348 // Duplicate type information for duplicated ast.Expr. 349 // All ast.Node implementations are *structs, 350 // so this case catches them all. 351 if e := rvToExpr(v); e != nil { 352 updateTypeInfo(tr.info, e, p.Interface().(ast.Expr)) 353 } 354 return v 355 356 case reflect.Interface: 357 v := reflect.New(p.Type()).Elem() 358 if elem := p.Elem(); elem.IsValid() { 359 v.Set(tr.subst(env, elem, pos)) 360 } 361 return v 362 } 363 364 return pattern 365 } 366 367 // -- utilities ------------------------------------------------------- 368 369 func rvToExpr(rv reflect.Value) ast.Expr { 370 if rv.CanInterface() { 371 if e, ok := rv.Interface().(ast.Expr); ok { 372 return e 373 } 374 } 375 return nil 376 } 377 378 // updateTypeInfo duplicates type information for the existing AST old 379 // so that it also applies to duplicated AST new. 380 func updateTypeInfo(info *types.Info, new, old ast.Expr) { 381 switch new := new.(type) { 382 case *ast.Ident: 383 orig := old.(*ast.Ident) 384 if obj, ok := info.Defs[orig]; ok { 385 info.Defs[new] = obj 386 } 387 if obj, ok := info.Uses[orig]; ok { 388 info.Uses[new] = obj 389 } 390 391 case *ast.SelectorExpr: 392 orig := old.(*ast.SelectorExpr) 393 if sel, ok := info.Selections[orig]; ok { 394 info.Selections[new] = sel 395 } 396 } 397 398 if tv, ok := info.Types[old]; ok { 399 info.Types[new] = tv 400 } 401 }