github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/refactor/eg/rewrite.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package eg
     6  
     7  // This file defines the AST rewriting pass.
     8  // Most of it was plundered directly from
     9  // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
    10  
    11  import (
    12  	"fmt"
    13  	"go/ast"
    14  	"go/token"
    15  	"go/types"
    16  	"os"
    17  	"reflect"
    18  	"sort"
    19  	"strconv"
    20  	"strings"
    21  
    22  	"github.com/jhump/golang-x-tools/go/ast/astutil"
    23  )
    24  
    25  // transformItem takes a reflect.Value representing a variable of type ast.Node
    26  // transforms its child elements recursively with apply, and then transforms the
    27  // actual element if it contains an expression.
    28  func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
    29  	// don't bother if val is invalid to start with
    30  	if !rv.IsValid() {
    31  		return reflect.Value{}, false, nil
    32  	}
    33  
    34  	rv, changed, newEnv := tr.apply(tr.transformItem, rv)
    35  
    36  	e := rvToExpr(rv)
    37  	if e == nil {
    38  		return rv, changed, newEnv
    39  	}
    40  
    41  	savedEnv := tr.env
    42  	tr.env = make(map[string]ast.Expr) // inefficient!  Use a slice of k/v pairs
    43  
    44  	if tr.matchExpr(tr.before, e) {
    45  		if tr.verbose {
    46  			fmt.Fprintf(os.Stderr, "%s matches %s",
    47  				astString(tr.fset, tr.before), astString(tr.fset, e))
    48  			if len(tr.env) > 0 {
    49  				fmt.Fprintf(os.Stderr, " with:")
    50  				for name, ast := range tr.env {
    51  					fmt.Fprintf(os.Stderr, " %s->%s",
    52  						name, astString(tr.fset, ast))
    53  				}
    54  			}
    55  			fmt.Fprintf(os.Stderr, "\n")
    56  		}
    57  		tr.nsubsts++
    58  
    59  		// Clone the replacement tree, performing parameter substitution.
    60  		// We update all positions to n.Pos() to aid comment placement.
    61  		rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
    62  			reflect.ValueOf(e.Pos()))
    63  		changed = true
    64  		newEnv = tr.env
    65  	}
    66  	tr.env = savedEnv
    67  
    68  	return rv, changed, newEnv
    69  }
    70  
    71  // Transform applies the transformation to the specified parsed file,
    72  // whose type information is supplied in info, and returns the number
    73  // of replacements that were made.
    74  //
    75  // It mutates the AST in place (the identity of the root node is
    76  // unchanged), and may add nodes for which no type information is
    77  // available in info.
    78  //
    79  // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
    80  //
    81  func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
    82  	if !tr.seenInfos[info] {
    83  		tr.seenInfos[info] = true
    84  		mergeTypeInfo(tr.info, info)
    85  	}
    86  	tr.currentPkg = pkg
    87  	tr.nsubsts = 0
    88  
    89  	if tr.verbose {
    90  		fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
    91  		fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
    92  		fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
    93  	}
    94  
    95  	o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
    96  	if changed {
    97  		panic("BUG")
    98  	}
    99  	file2 := o.Interface().(*ast.File)
   100  
   101  	// By construction, the root node is unchanged.
   102  	if file != file2 {
   103  		panic("BUG")
   104  	}
   105  
   106  	// Add any necessary imports.
   107  	// TODO(adonovan): remove no-longer needed imports too.
   108  	if tr.nsubsts > 0 {
   109  		pkgs := make(map[string]*types.Package)
   110  		for obj := range tr.importedObjs {
   111  			pkgs[obj.Pkg().Path()] = obj.Pkg()
   112  		}
   113  
   114  		for _, imp := range file.Imports {
   115  			path, _ := strconv.Unquote(imp.Path.Value)
   116  			delete(pkgs, path)
   117  		}
   118  		delete(pkgs, pkg.Path()) // don't import self
   119  
   120  		// NB: AddImport may completely replace the AST!
   121  		// It thus renders info and tr.info no longer relevant to file.
   122  		var paths []string
   123  		for path := range pkgs {
   124  			paths = append(paths, path)
   125  		}
   126  		sort.Strings(paths)
   127  		for _, path := range paths {
   128  			astutil.AddImport(tr.fset, file, path)
   129  		}
   130  	}
   131  
   132  	tr.currentPkg = nil
   133  
   134  	return tr.nsubsts
   135  }
   136  
   137  // setValue is a wrapper for x.SetValue(y); it protects
   138  // the caller from panics if x cannot be changed to y.
   139  func setValue(x, y reflect.Value) {
   140  	// don't bother if y is invalid to start with
   141  	if !y.IsValid() {
   142  		return
   143  	}
   144  	defer func() {
   145  		if x := recover(); x != nil {
   146  			if s, ok := x.(string); ok &&
   147  				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
   148  				// x cannot be set to y - ignore this rewrite
   149  				return
   150  			}
   151  			panic(x)
   152  		}
   153  	}()
   154  	x.Set(y)
   155  }
   156  
   157  // Values/types for special cases.
   158  var (
   159  	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
   160  	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
   161  
   162  	identType        = reflect.TypeOf((*ast.Ident)(nil))
   163  	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
   164  	objectPtrType    = reflect.TypeOf((*ast.Object)(nil))
   165  	statementType    = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
   166  	positionType     = reflect.TypeOf(token.NoPos)
   167  	scopePtrType     = reflect.TypeOf((*ast.Scope)(nil))
   168  )
   169  
   170  // apply replaces each AST field x in val with f(x), returning val.
   171  // To avoid extra conversions, f operates on the reflect.Value form.
   172  // f takes a reflect.Value representing the variable to modify of type ast.Node.
   173  // It returns a reflect.Value containing the transformed value of type ast.Node,
   174  // whether any change was made, and a map of identifiers to ast.Expr (so we can
   175  // do contextually correct substitutions in the parent statements).
   176  func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
   177  	if !val.IsValid() {
   178  		return reflect.Value{}, false, nil
   179  	}
   180  
   181  	// *ast.Objects introduce cycles and are likely incorrect after
   182  	// rewrite; don't follow them but replace with nil instead
   183  	if val.Type() == objectPtrType {
   184  		return objectPtrNil, false, nil
   185  	}
   186  
   187  	// similarly for scopes: they are likely incorrect after a rewrite;
   188  	// replace them with nil
   189  	if val.Type() == scopePtrType {
   190  		return scopePtrNil, false, nil
   191  	}
   192  
   193  	switch v := reflect.Indirect(val); v.Kind() {
   194  	case reflect.Slice:
   195  		// no possible rewriting of statements.
   196  		if v.Type().Elem() != statementType {
   197  			changed := false
   198  			var envp map[string]ast.Expr
   199  			for i := 0; i < v.Len(); i++ {
   200  				e := v.Index(i)
   201  				o, localchanged, env := f(e)
   202  				if localchanged {
   203  					changed = true
   204  					// we clobber envp here,
   205  					// which means if we have two successive
   206  					// replacements inside the same statement
   207  					// we will only generate the setup for one of them.
   208  					envp = env
   209  				}
   210  				setValue(e, o)
   211  			}
   212  			return val, changed, envp
   213  		}
   214  
   215  		// statements are rewritten.
   216  		var out []ast.Stmt
   217  		for i := 0; i < v.Len(); i++ {
   218  			e := v.Index(i)
   219  			o, changed, env := f(e)
   220  			if changed {
   221  				for _, s := range tr.afterStmts {
   222  					t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
   223  					out = append(out, t.(ast.Stmt))
   224  				}
   225  			}
   226  			setValue(e, o)
   227  			out = append(out, e.Interface().(ast.Stmt))
   228  		}
   229  		return reflect.ValueOf(out), false, nil
   230  	case reflect.Struct:
   231  		changed := false
   232  		var envp map[string]ast.Expr
   233  		for i := 0; i < v.NumField(); i++ {
   234  			e := v.Field(i)
   235  			o, localchanged, env := f(e)
   236  			if localchanged {
   237  				changed = true
   238  				envp = env
   239  			}
   240  			setValue(e, o)
   241  		}
   242  		return val, changed, envp
   243  	case reflect.Interface:
   244  		e := v.Elem()
   245  		o, changed, env := f(e)
   246  		setValue(v, o)
   247  		return val, changed, env
   248  	}
   249  	return val, false, nil
   250  }
   251  
   252  // subst returns a copy of (replacement) pattern with values from env
   253  // substituted in place of wildcards and pos used as the position of
   254  // tokens from the pattern.  if env == nil, subst returns a copy of
   255  // pattern and doesn't change the line number information.
   256  func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
   257  	if !pattern.IsValid() {
   258  		return reflect.Value{}
   259  	}
   260  
   261  	// *ast.Objects introduce cycles and are likely incorrect after
   262  	// rewrite; don't follow them but replace with nil instead
   263  	if pattern.Type() == objectPtrType {
   264  		return objectPtrNil
   265  	}
   266  
   267  	// similarly for scopes: they are likely incorrect after a rewrite;
   268  	// replace them with nil
   269  	if pattern.Type() == scopePtrType {
   270  		return scopePtrNil
   271  	}
   272  
   273  	// Wildcard gets replaced with map value.
   274  	if env != nil && pattern.Type() == identType {
   275  		id := pattern.Interface().(*ast.Ident)
   276  		if old, ok := env[id.Name]; ok {
   277  			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
   278  		}
   279  	}
   280  
   281  	// Emit qualified identifiers in the pattern by appropriate
   282  	// (possibly qualified) identifier in the input.
   283  	//
   284  	// The template cannot contain dot imports, so all identifiers
   285  	// for imported objects are explicitly qualified.
   286  	//
   287  	// We assume (unsoundly) that there are no dot or named
   288  	// imports in the input code, nor are any imported package
   289  	// names shadowed, so the usual normal qualified identifier
   290  	// syntax may be used.
   291  	// TODO(adonovan): fix: avoid this assumption.
   292  	//
   293  	// A refactoring may be applied to a package referenced by the
   294  	// template.  Objects belonging to the current package are
   295  	// denoted by unqualified identifiers.
   296  	//
   297  	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
   298  		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
   299  		if obj != nil {
   300  			if sel, ok := tr.importedObjs[obj]; ok {
   301  				var id ast.Expr
   302  				if obj.Pkg() == tr.currentPkg {
   303  					id = sel.Sel // unqualified
   304  				} else {
   305  					id = sel // pkg-qualified
   306  				}
   307  
   308  				// Return a clone of id.
   309  				saved := tr.importedObjs
   310  				tr.importedObjs = nil // break cycle
   311  				r := tr.subst(nil, reflect.ValueOf(id), pos)
   312  				tr.importedObjs = saved
   313  				return r
   314  			}
   315  		}
   316  	}
   317  
   318  	if pos.IsValid() && pattern.Type() == positionType {
   319  		// use new position only if old position was valid in the first place
   320  		if old := pattern.Interface().(token.Pos); !old.IsValid() {
   321  			return pattern
   322  		}
   323  		return pos
   324  	}
   325  
   326  	// Otherwise copy.
   327  	switch p := pattern; p.Kind() {
   328  	case reflect.Slice:
   329  		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
   330  		for i := 0; i < p.Len(); i++ {
   331  			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
   332  		}
   333  		return v
   334  
   335  	case reflect.Struct:
   336  		v := reflect.New(p.Type()).Elem()
   337  		for i := 0; i < p.NumField(); i++ {
   338  			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
   339  		}
   340  		return v
   341  
   342  	case reflect.Ptr:
   343  		v := reflect.New(p.Type()).Elem()
   344  		if elem := p.Elem(); elem.IsValid() {
   345  			v.Set(tr.subst(env, elem, pos).Addr())
   346  		}
   347  
   348  		// Duplicate type information for duplicated ast.Expr.
   349  		// All ast.Node implementations are *structs,
   350  		// so this case catches them all.
   351  		if e := rvToExpr(v); e != nil {
   352  			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
   353  		}
   354  		return v
   355  
   356  	case reflect.Interface:
   357  		v := reflect.New(p.Type()).Elem()
   358  		if elem := p.Elem(); elem.IsValid() {
   359  			v.Set(tr.subst(env, elem, pos))
   360  		}
   361  		return v
   362  	}
   363  
   364  	return pattern
   365  }
   366  
   367  // -- utilities -------------------------------------------------------
   368  
   369  func rvToExpr(rv reflect.Value) ast.Expr {
   370  	if rv.CanInterface() {
   371  		if e, ok := rv.Interface().(ast.Expr); ok {
   372  			return e
   373  		}
   374  	}
   375  	return nil
   376  }
   377  
   378  // updateTypeInfo duplicates type information for the existing AST old
   379  // so that it also applies to duplicated AST new.
   380  func updateTypeInfo(info *types.Info, new, old ast.Expr) {
   381  	switch new := new.(type) {
   382  	case *ast.Ident:
   383  		orig := old.(*ast.Ident)
   384  		if obj, ok := info.Defs[orig]; ok {
   385  			info.Defs[new] = obj
   386  		}
   387  		if obj, ok := info.Uses[orig]; ok {
   388  			info.Uses[new] = obj
   389  		}
   390  
   391  	case *ast.SelectorExpr:
   392  		orig := old.(*ast.SelectorExpr)
   393  		if sel, ok := info.Selections[orig]; ok {
   394  			info.Selections[new] = sel
   395  		}
   396  	}
   397  
   398  	if tv, ok := info.Types[old]; ok {
   399  		info.Types[new] = tv
   400  	}
   401  }