golang.org/x/tools@v0.21.0/refactor/eg/rewrite.go (about) 1 // Copyright 2014 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package eg 6 7 // This file defines the AST rewriting pass. 8 // Most of it was plundered directly from 9 // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution). 10 11 import ( 12 "fmt" 13 "go/ast" 14 "go/token" 15 "go/types" 16 "os" 17 "reflect" 18 "sort" 19 "strconv" 20 "strings" 21 22 "golang.org/x/tools/go/ast/astutil" 23 ) 24 25 // transformItem takes a reflect.Value representing a variable of type ast.Node 26 // transforms its child elements recursively with apply, and then transforms the 27 // actual element if it contains an expression. 28 func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { 29 // don't bother if val is invalid to start with 30 if !rv.IsValid() { 31 return reflect.Value{}, false, nil 32 } 33 34 rv, changed, newEnv := tr.apply(tr.transformItem, rv) 35 36 e := rvToExpr(rv) 37 if e == nil { 38 return rv, changed, newEnv 39 } 40 41 savedEnv := tr.env 42 tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs 43 44 if tr.matchExpr(tr.before, e) { 45 if tr.verbose { 46 fmt.Fprintf(os.Stderr, "%s matches %s", 47 astString(tr.fset, tr.before), astString(tr.fset, e)) 48 if len(tr.env) > 0 { 49 fmt.Fprintf(os.Stderr, " with:") 50 for name, ast := range tr.env { 51 fmt.Fprintf(os.Stderr, " %s->%s", 52 name, astString(tr.fset, ast)) 53 } 54 } 55 fmt.Fprintf(os.Stderr, "\n") 56 } 57 tr.nsubsts++ 58 59 // Clone the replacement tree, performing parameter substitution. 60 // We update all positions to n.Pos() to aid comment placement. 61 rv = tr.subst(tr.env, reflect.ValueOf(tr.after), 62 reflect.ValueOf(e.Pos())) 63 changed = true 64 newEnv = tr.env 65 } 66 tr.env = savedEnv 67 68 return rv, changed, newEnv 69 } 70 71 // Transform applies the transformation to the specified parsed file, 72 // whose type information is supplied in info, and returns the number 73 // of replacements that were made. 74 // 75 // It mutates the AST in place (the identity of the root node is 76 // unchanged), and may add nodes for which no type information is 77 // available in info. 78 // 79 // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. 80 func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { 81 if !tr.seenInfos[info] { 82 tr.seenInfos[info] = true 83 mergeTypeInfo(tr.info, info) 84 } 85 tr.currentPkg = pkg 86 tr.nsubsts = 0 87 88 if tr.verbose { 89 fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) 90 fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) 91 fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts) 92 } 93 94 o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file)) 95 if changed { 96 panic("BUG") 97 } 98 file2 := o.Interface().(*ast.File) 99 100 // By construction, the root node is unchanged. 101 if file != file2 { 102 panic("BUG") 103 } 104 105 // Add any necessary imports. 106 // TODO(adonovan): remove no-longer needed imports too. 107 if tr.nsubsts > 0 { 108 pkgs := make(map[string]*types.Package) 109 for obj := range tr.importedObjs { 110 pkgs[obj.Pkg().Path()] = obj.Pkg() 111 } 112 113 for _, imp := range file.Imports { 114 path, _ := strconv.Unquote(imp.Path.Value) 115 delete(pkgs, path) 116 } 117 delete(pkgs, pkg.Path()) // don't import self 118 119 // NB: AddImport may completely replace the AST! 120 // It thus renders info and tr.info no longer relevant to file. 121 var paths []string 122 for path := range pkgs { 123 paths = append(paths, path) 124 } 125 sort.Strings(paths) 126 for _, path := range paths { 127 astutil.AddImport(tr.fset, file, path) 128 } 129 } 130 131 tr.currentPkg = nil 132 133 return tr.nsubsts 134 } 135 136 // setValue is a wrapper for x.SetValue(y); it protects 137 // the caller from panics if x cannot be changed to y. 138 func setValue(x, y reflect.Value) { 139 // don't bother if y is invalid to start with 140 if !y.IsValid() { 141 return 142 } 143 defer func() { 144 if x := recover(); x != nil { 145 if s, ok := x.(string); ok && 146 (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) { 147 // x cannot be set to y - ignore this rewrite 148 return 149 } 150 panic(x) 151 } 152 }() 153 x.Set(y) 154 } 155 156 // Values/types for special cases. 157 var ( 158 objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) 159 scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) 160 161 identType = reflect.TypeOf((*ast.Ident)(nil)) 162 selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil)) 163 objectPtrType = reflect.TypeOf((*ast.Object)(nil)) 164 statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem() 165 positionType = reflect.TypeOf(token.NoPos) 166 scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) 167 ) 168 169 // apply replaces each AST field x in val with f(x), returning val. 170 // To avoid extra conversions, f operates on the reflect.Value form. 171 // f takes a reflect.Value representing the variable to modify of type ast.Node. 172 // It returns a reflect.Value containing the transformed value of type ast.Node, 173 // whether any change was made, and a map of identifiers to ast.Expr (so we can 174 // do contextually correct substitutions in the parent statements). 175 func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { 176 if !val.IsValid() { 177 return reflect.Value{}, false, nil 178 } 179 180 // *ast.Objects introduce cycles and are likely incorrect after 181 // rewrite; don't follow them but replace with nil instead 182 if val.Type() == objectPtrType { 183 return objectPtrNil, false, nil 184 } 185 186 // similarly for scopes: they are likely incorrect after a rewrite; 187 // replace them with nil 188 if val.Type() == scopePtrType { 189 return scopePtrNil, false, nil 190 } 191 192 switch v := reflect.Indirect(val); v.Kind() { 193 case reflect.Slice: 194 // no possible rewriting of statements. 195 if v.Type().Elem() != statementType { 196 changed := false 197 var envp map[string]ast.Expr 198 for i := 0; i < v.Len(); i++ { 199 e := v.Index(i) 200 o, localchanged, env := f(e) 201 if localchanged { 202 changed = true 203 // we clobber envp here, 204 // which means if we have two successive 205 // replacements inside the same statement 206 // we will only generate the setup for one of them. 207 envp = env 208 } 209 setValue(e, o) 210 } 211 return val, changed, envp 212 } 213 214 // statements are rewritten. 215 var out []ast.Stmt 216 for i := 0; i < v.Len(); i++ { 217 e := v.Index(i) 218 o, changed, env := f(e) 219 if changed { 220 for _, s := range tr.afterStmts { 221 t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface() 222 out = append(out, t.(ast.Stmt)) 223 } 224 } 225 setValue(e, o) 226 out = append(out, e.Interface().(ast.Stmt)) 227 } 228 return reflect.ValueOf(out), false, nil 229 case reflect.Struct: 230 changed := false 231 var envp map[string]ast.Expr 232 for i := 0; i < v.NumField(); i++ { 233 e := v.Field(i) 234 o, localchanged, env := f(e) 235 if localchanged { 236 changed = true 237 envp = env 238 } 239 setValue(e, o) 240 } 241 return val, changed, envp 242 case reflect.Interface: 243 e := v.Elem() 244 o, changed, env := f(e) 245 setValue(v, o) 246 return val, changed, env 247 } 248 return val, false, nil 249 } 250 251 // subst returns a copy of (replacement) pattern with values from env 252 // substituted in place of wildcards and pos used as the position of 253 // tokens from the pattern. if env == nil, subst returns a copy of 254 // pattern and doesn't change the line number information. 255 func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value { 256 if !pattern.IsValid() { 257 return reflect.Value{} 258 } 259 260 // *ast.Objects introduce cycles and are likely incorrect after 261 // rewrite; don't follow them but replace with nil instead 262 if pattern.Type() == objectPtrType { 263 return objectPtrNil 264 } 265 266 // similarly for scopes: they are likely incorrect after a rewrite; 267 // replace them with nil 268 if pattern.Type() == scopePtrType { 269 return scopePtrNil 270 } 271 272 // Wildcard gets replaced with map value. 273 if env != nil && pattern.Type() == identType { 274 id := pattern.Interface().(*ast.Ident) 275 if old, ok := env[id.Name]; ok { 276 return tr.subst(nil, reflect.ValueOf(old), reflect.Value{}) 277 } 278 } 279 280 // Emit qualified identifiers in the pattern by appropriate 281 // (possibly qualified) identifier in the input. 282 // 283 // The template cannot contain dot imports, so all identifiers 284 // for imported objects are explicitly qualified. 285 // 286 // We assume (unsoundly) that there are no dot or named 287 // imports in the input code, nor are any imported package 288 // names shadowed, so the usual normal qualified identifier 289 // syntax may be used. 290 // TODO(adonovan): fix: avoid this assumption. 291 // 292 // A refactoring may be applied to a package referenced by the 293 // template. Objects belonging to the current package are 294 // denoted by unqualified identifiers. 295 // 296 if tr.importedObjs != nil && pattern.Type() == selectorExprType { 297 obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info) 298 if obj != nil { 299 if sel, ok := tr.importedObjs[obj]; ok { 300 var id ast.Expr 301 if obj.Pkg() == tr.currentPkg { 302 id = sel.Sel // unqualified 303 } else { 304 id = sel // pkg-qualified 305 } 306 307 // Return a clone of id. 308 saved := tr.importedObjs 309 tr.importedObjs = nil // break cycle 310 r := tr.subst(nil, reflect.ValueOf(id), pos) 311 tr.importedObjs = saved 312 return r 313 } 314 } 315 } 316 317 if pos.IsValid() && pattern.Type() == positionType { 318 // use new position only if old position was valid in the first place 319 if old := pattern.Interface().(token.Pos); !old.IsValid() { 320 return pattern 321 } 322 return pos 323 } 324 325 // Otherwise copy. 326 switch p := pattern; p.Kind() { 327 case reflect.Slice: 328 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) 329 for i := 0; i < p.Len(); i++ { 330 v.Index(i).Set(tr.subst(env, p.Index(i), pos)) 331 } 332 return v 333 334 case reflect.Struct: 335 v := reflect.New(p.Type()).Elem() 336 for i := 0; i < p.NumField(); i++ { 337 v.Field(i).Set(tr.subst(env, p.Field(i), pos)) 338 } 339 return v 340 341 case reflect.Ptr: 342 v := reflect.New(p.Type()).Elem() 343 if elem := p.Elem(); elem.IsValid() { 344 v.Set(tr.subst(env, elem, pos).Addr()) 345 } 346 347 // Duplicate type information for duplicated ast.Expr. 348 // All ast.Node implementations are *structs, 349 // so this case catches them all. 350 if e := rvToExpr(v); e != nil { 351 updateTypeInfo(tr.info, e, p.Interface().(ast.Expr)) 352 } 353 return v 354 355 case reflect.Interface: 356 v := reflect.New(p.Type()).Elem() 357 if elem := p.Elem(); elem.IsValid() { 358 v.Set(tr.subst(env, elem, pos)) 359 } 360 return v 361 } 362 363 return pattern 364 } 365 366 // -- utilities ------------------------------------------------------- 367 368 func rvToExpr(rv reflect.Value) ast.Expr { 369 if rv.CanInterface() { 370 if e, ok := rv.Interface().(ast.Expr); ok { 371 return e 372 } 373 } 374 return nil 375 } 376 377 // updateTypeInfo duplicates type information for the existing AST old 378 // so that it also applies to duplicated AST new. 379 func updateTypeInfo(info *types.Info, new, old ast.Expr) { 380 switch new := new.(type) { 381 case *ast.Ident: 382 orig := old.(*ast.Ident) 383 if obj, ok := info.Defs[orig]; ok { 384 info.Defs[new] = obj 385 } 386 if obj, ok := info.Uses[orig]; ok { 387 info.Uses[new] = obj 388 } 389 390 case *ast.SelectorExpr: 391 orig := old.(*ast.SelectorExpr) 392 if sel, ok := info.Selections[orig]; ok { 393 info.Selections[new] = sel 394 } 395 } 396 397 if tv, ok := info.Types[old]; ok { 398 info.Types[new] = tv 399 } 400 }