github.com/golangci/gofmt@v0.0.0-20231018234816-f50ced29576e/gofmt/rewrite.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gofmt
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"reflect"
    14  	"strings"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  func initRewrite() {
    20  	if *rewriteRule == "" {
    21  		rewrite = nil // disable any previous rewrite
    22  		return
    23  	}
    24  	f := strings.Split(*rewriteRule, "->")
    25  	if len(f) != 2 {
    26  		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
    27  		os.Exit(2)
    28  	}
    29  	pattern := parseExpr(f[0], "pattern")
    30  	replace := parseExpr(f[1], "replacement")
    31  	rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
    32  		return rewriteFile(fset, pattern, replace, p)
    33  	}
    34  }
    35  
    36  // parseExpr parses s as an expression.
    37  // It might make sense to expand this to allow statement patterns,
    38  // but there are problems with preserving formatting and also
    39  // with what a wildcard for a statement looks like.
    40  func parseExpr(s, what string) ast.Expr {
    41  	x, err := parser.ParseExpr(s)
    42  	if err != nil {
    43  		fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
    44  		os.Exit(2)
    45  	}
    46  	return x
    47  }
    48  
    49  // Keep this function for debugging.
    50  /*
    51  func dump(msg string, val reflect.Value) {
    52  	fmt.Printf("%s:\n", msg)
    53  	ast.Print(fileSet, val.Interface())
    54  	fmt.Println()
    55  }
    56  */
    57  
    58  // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
    59  func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
    60  	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
    61  	m := make(map[string]reflect.Value)
    62  	pat := reflect.ValueOf(pattern)
    63  	repl := reflect.ValueOf(replace)
    64  
    65  	var rewriteVal func(val reflect.Value) reflect.Value
    66  	rewriteVal = func(val reflect.Value) reflect.Value {
    67  		// don't bother if val is invalid to start with
    68  		if !val.IsValid() {
    69  			return reflect.Value{}
    70  		}
    71  		val = apply(rewriteVal, val)
    72  		for k := range m {
    73  			delete(m, k)
    74  		}
    75  		if match(m, pat, val) {
    76  			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
    77  		}
    78  		return val
    79  	}
    80  
    81  	r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
    82  	r.Comments = cmap.Filter(r).Comments() // recreate comments list
    83  	return r
    84  }
    85  
    86  // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
    87  func set(x, y reflect.Value) {
    88  	// don't bother if x cannot be set or y is invalid
    89  	if !x.CanSet() || !y.IsValid() {
    90  		return
    91  	}
    92  	defer func() {
    93  		if x := recover(); x != nil {
    94  			if s, ok := x.(string); ok &&
    95  				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
    96  				// x cannot be set to y - ignore this rewrite
    97  				return
    98  			}
    99  			panic(x)
   100  		}
   101  	}()
   102  	x.Set(y)
   103  }
   104  
   105  // Values/types for special cases.
   106  var (
   107  	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
   108  	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
   109  
   110  	identType     = reflect.TypeOf((*ast.Ident)(nil))
   111  	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
   112  	positionType  = reflect.TypeOf(token.NoPos)
   113  	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
   114  	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
   115  )
   116  
   117  // apply replaces each AST field x in val with f(x), returning val.
   118  // To avoid extra conversions, f operates on the reflect.Value form.
   119  func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
   120  	if !val.IsValid() {
   121  		return reflect.Value{}
   122  	}
   123  
   124  	// *ast.Objects introduce cycles and are likely incorrect after
   125  	// rewrite; don't follow them but replace with nil instead
   126  	if val.Type() == objectPtrType {
   127  		return objectPtrNil
   128  	}
   129  
   130  	// similarly for scopes: they are likely incorrect after a rewrite;
   131  	// replace them with nil
   132  	if val.Type() == scopePtrType {
   133  		return scopePtrNil
   134  	}
   135  
   136  	switch v := reflect.Indirect(val); v.Kind() {
   137  	case reflect.Slice:
   138  		for i := 0; i < v.Len(); i++ {
   139  			e := v.Index(i)
   140  			set(e, f(e))
   141  		}
   142  	case reflect.Struct:
   143  		for i := 0; i < v.NumField(); i++ {
   144  			e := v.Field(i)
   145  			set(e, f(e))
   146  		}
   147  	case reflect.Interface:
   148  		e := v.Elem()
   149  		set(v, f(e))
   150  	}
   151  	return val
   152  }
   153  
   154  func isWildcard(s string) bool {
   155  	rune, size := utf8.DecodeRuneInString(s)
   156  	return size == len(s) && unicode.IsLower(rune)
   157  }
   158  
   159  // match reports whether pattern matches val,
   160  // recording wildcard submatches in m.
   161  // If m == nil, match checks whether pattern == val.
   162  func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
   163  	// Wildcard matches any expression. If it appears multiple
   164  	// times in the pattern, it must match the same expression
   165  	// each time.
   166  	if m != nil && pattern.IsValid() && pattern.Type() == identType {
   167  		name := pattern.Interface().(*ast.Ident).Name
   168  		if isWildcard(name) && val.IsValid() {
   169  			// wildcards only match valid (non-nil) expressions.
   170  			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
   171  				if old, ok := m[name]; ok {
   172  					return match(nil, old, val)
   173  				}
   174  				m[name] = val
   175  				return true
   176  			}
   177  		}
   178  	}
   179  
   180  	// Otherwise, pattern and val must match recursively.
   181  	if !pattern.IsValid() || !val.IsValid() {
   182  		return !pattern.IsValid() && !val.IsValid()
   183  	}
   184  	if pattern.Type() != val.Type() {
   185  		return false
   186  	}
   187  
   188  	// Special cases.
   189  	switch pattern.Type() {
   190  	case identType:
   191  		// For identifiers, only the names need to match
   192  		// (and none of the other *ast.Object information).
   193  		// This is a common case, handle it all here instead
   194  		// of recursing down any further via reflection.
   195  		p := pattern.Interface().(*ast.Ident)
   196  		v := val.Interface().(*ast.Ident)
   197  		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
   198  	case objectPtrType, positionType:
   199  		// object pointers and token positions always match
   200  		return true
   201  	case callExprType:
   202  		// For calls, the Ellipsis fields (token.Position) must
   203  		// match since that is how f(x) and f(x...) are different.
   204  		// Check them here but fall through for the remaining fields.
   205  		p := pattern.Interface().(*ast.CallExpr)
   206  		v := val.Interface().(*ast.CallExpr)
   207  		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
   208  			return false
   209  		}
   210  	}
   211  
   212  	p := reflect.Indirect(pattern)
   213  	v := reflect.Indirect(val)
   214  	if !p.IsValid() || !v.IsValid() {
   215  		return !p.IsValid() && !v.IsValid()
   216  	}
   217  
   218  	switch p.Kind() {
   219  	case reflect.Slice:
   220  		if p.Len() != v.Len() {
   221  			return false
   222  		}
   223  		for i := 0; i < p.Len(); i++ {
   224  			if !match(m, p.Index(i), v.Index(i)) {
   225  				return false
   226  			}
   227  		}
   228  		return true
   229  
   230  	case reflect.Struct:
   231  		for i := 0; i < p.NumField(); i++ {
   232  			if !match(m, p.Field(i), v.Field(i)) {
   233  				return false
   234  			}
   235  		}
   236  		return true
   237  
   238  	case reflect.Interface:
   239  		return match(m, p.Elem(), v.Elem())
   240  	}
   241  
   242  	// Handle token integers, etc.
   243  	return p.Interface() == v.Interface()
   244  }
   245  
   246  // subst returns a copy of pattern with values from m substituted in place
   247  // of wildcards and pos used as the position of tokens from the pattern.
   248  // if m == nil, subst returns a copy of pattern and doesn't change the line
   249  // number information.
   250  func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
   251  	if !pattern.IsValid() {
   252  		return reflect.Value{}
   253  	}
   254  
   255  	// Wildcard gets replaced with map value.
   256  	if m != nil && pattern.Type() == identType {
   257  		name := pattern.Interface().(*ast.Ident).Name
   258  		if isWildcard(name) {
   259  			if old, ok := m[name]; ok {
   260  				return subst(nil, old, reflect.Value{})
   261  			}
   262  		}
   263  	}
   264  
   265  	if pos.IsValid() && pattern.Type() == positionType {
   266  		// use new position only if old position was valid in the first place
   267  		if old := pattern.Interface().(token.Pos); !old.IsValid() {
   268  			return pattern
   269  		}
   270  		return pos
   271  	}
   272  
   273  	// Otherwise copy.
   274  	switch p := pattern; p.Kind() {
   275  	case reflect.Slice:
   276  		if p.IsNil() {
   277  			// Do not turn nil slices into empty slices. go/ast
   278  			// guarantees that certain lists will be nil if not
   279  			// populated.
   280  			return reflect.Zero(p.Type())
   281  		}
   282  		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
   283  		for i := 0; i < p.Len(); i++ {
   284  			v.Index(i).Set(subst(m, p.Index(i), pos))
   285  		}
   286  		return v
   287  
   288  	case reflect.Struct:
   289  		v := reflect.New(p.Type()).Elem()
   290  		for i := 0; i < p.NumField(); i++ {
   291  			v.Field(i).Set(subst(m, p.Field(i), pos))
   292  		}
   293  		return v
   294  
   295  	case reflect.Pointer:
   296  		v := reflect.New(p.Type()).Elem()
   297  		if elem := p.Elem(); elem.IsValid() {
   298  			v.Set(subst(m, elem, pos).Addr())
   299  		}
   300  		return v
   301  
   302  	case reflect.Interface:
   303  		v := reflect.New(p.Type()).Elem()
   304  		if elem := p.Elem(); elem.IsValid() {
   305  			v.Set(subst(m, elem, pos))
   306  		}
   307  		return v
   308  	}
   309  
   310  	return pattern
   311  }