github.com/spotify/syslog-redirector-golang@v0.0.0-20140320174030-4859f03d829a/src/cmd/gofmt/rewrite.go (about)

     1  // Copyright 2009 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"reflect"
    14  	"strings"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  func initRewrite() {
    20  	if *rewriteRule == "" {
    21  		rewrite = nil // disable any previous rewrite
    22  		return
    23  	}
    24  	f := strings.Split(*rewriteRule, "->")
    25  	if len(f) != 2 {
    26  		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
    27  		os.Exit(2)
    28  	}
    29  	pattern := parseExpr(f[0], "pattern")
    30  	replace := parseExpr(f[1], "replacement")
    31  	rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
    32  }
    33  
    34  // parseExpr parses s as an expression.
    35  // It might make sense to expand this to allow statement patterns,
    36  // but there are problems with preserving formatting and also
    37  // with what a wildcard for a statement looks like.
    38  func parseExpr(s, what string) ast.Expr {
    39  	x, err := parser.ParseExpr(s)
    40  	if err != nil {
    41  		fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
    42  		os.Exit(2)
    43  	}
    44  	return x
    45  }
    46  
    47  // Keep this function for debugging.
    48  /*
    49  func dump(msg string, val reflect.Value) {
    50  	fmt.Printf("%s:\n", msg)
    51  	ast.Print(fset, val.Interface())
    52  	fmt.Println()
    53  }
    54  */
    55  
    56  // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
    57  func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
    58  	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
    59  	m := make(map[string]reflect.Value)
    60  	pat := reflect.ValueOf(pattern)
    61  	repl := reflect.ValueOf(replace)
    62  	var f func(val reflect.Value) reflect.Value // f is recursive
    63  	f = func(val reflect.Value) reflect.Value {
    64  		// don't bother if val is invalid to start with
    65  		if !val.IsValid() {
    66  			return reflect.Value{}
    67  		}
    68  		for k := range m {
    69  			delete(m, k)
    70  		}
    71  		val = apply(f, val)
    72  		if match(m, pat, val) {
    73  			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
    74  		}
    75  		return val
    76  	}
    77  	r := apply(f, reflect.ValueOf(p)).Interface().(*ast.File)
    78  	r.Comments = cmap.Filter(r).Comments() // recreate comments list
    79  	return r
    80  }
    81  
    82  // setValue is a wrapper for x.SetValue(y); it protects
    83  // the caller from panics if x cannot be changed to y.
    84  func setValue(x, y reflect.Value) {
    85  	// don't bother if y is invalid to start with
    86  	if !y.IsValid() {
    87  		return
    88  	}
    89  	defer func() {
    90  		if x := recover(); x != nil {
    91  			if s, ok := x.(string); ok &&
    92  				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
    93  				// x cannot be set to y - ignore this rewrite
    94  				return
    95  			}
    96  			panic(x)
    97  		}
    98  	}()
    99  	x.Set(y)
   100  }
   101  
   102  // Values/types for special cases.
   103  var (
   104  	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
   105  	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
   106  
   107  	identType     = reflect.TypeOf((*ast.Ident)(nil))
   108  	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
   109  	positionType  = reflect.TypeOf(token.NoPos)
   110  	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
   111  	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
   112  )
   113  
   114  // apply replaces each AST field x in val with f(x), returning val.
   115  // To avoid extra conversions, f operates on the reflect.Value form.
   116  func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
   117  	if !val.IsValid() {
   118  		return reflect.Value{}
   119  	}
   120  
   121  	// *ast.Objects introduce cycles and are likely incorrect after
   122  	// rewrite; don't follow them but replace with nil instead
   123  	if val.Type() == objectPtrType {
   124  		return objectPtrNil
   125  	}
   126  
   127  	// similarly for scopes: they are likely incorrect after a rewrite;
   128  	// replace them with nil
   129  	if val.Type() == scopePtrType {
   130  		return scopePtrNil
   131  	}
   132  
   133  	switch v := reflect.Indirect(val); v.Kind() {
   134  	case reflect.Slice:
   135  		for i := 0; i < v.Len(); i++ {
   136  			e := v.Index(i)
   137  			setValue(e, f(e))
   138  		}
   139  	case reflect.Struct:
   140  		for i := 0; i < v.NumField(); i++ {
   141  			e := v.Field(i)
   142  			setValue(e, f(e))
   143  		}
   144  	case reflect.Interface:
   145  		e := v.Elem()
   146  		setValue(v, f(e))
   147  	}
   148  	return val
   149  }
   150  
   151  func isWildcard(s string) bool {
   152  	rune, size := utf8.DecodeRuneInString(s)
   153  	return size == len(s) && unicode.IsLower(rune)
   154  }
   155  
   156  // match returns true if pattern matches val,
   157  // recording wildcard submatches in m.
   158  // If m == nil, match checks whether pattern == val.
   159  func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
   160  	// Wildcard matches any expression.  If it appears multiple
   161  	// times in the pattern, it must match the same expression
   162  	// each time.
   163  	if m != nil && pattern.IsValid() && pattern.Type() == identType {
   164  		name := pattern.Interface().(*ast.Ident).Name
   165  		if isWildcard(name) && val.IsValid() {
   166  			// wildcards only match valid (non-nil) expressions.
   167  			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
   168  				if old, ok := m[name]; ok {
   169  					return match(nil, old, val)
   170  				}
   171  				m[name] = val
   172  				return true
   173  			}
   174  		}
   175  	}
   176  
   177  	// Otherwise, pattern and val must match recursively.
   178  	if !pattern.IsValid() || !val.IsValid() {
   179  		return !pattern.IsValid() && !val.IsValid()
   180  	}
   181  	if pattern.Type() != val.Type() {
   182  		return false
   183  	}
   184  
   185  	// Special cases.
   186  	switch pattern.Type() {
   187  	case identType:
   188  		// For identifiers, only the names need to match
   189  		// (and none of the other *ast.Object information).
   190  		// This is a common case, handle it all here instead
   191  		// of recursing down any further via reflection.
   192  		p := pattern.Interface().(*ast.Ident)
   193  		v := val.Interface().(*ast.Ident)
   194  		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
   195  	case objectPtrType, positionType:
   196  		// object pointers and token positions always match
   197  		return true
   198  	case callExprType:
   199  		// For calls, the Ellipsis fields (token.Position) must
   200  		// match since that is how f(x) and f(x...) are different.
   201  		// Check them here but fall through for the remaining fields.
   202  		p := pattern.Interface().(*ast.CallExpr)
   203  		v := val.Interface().(*ast.CallExpr)
   204  		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
   205  			return false
   206  		}
   207  	}
   208  
   209  	p := reflect.Indirect(pattern)
   210  	v := reflect.Indirect(val)
   211  	if !p.IsValid() || !v.IsValid() {
   212  		return !p.IsValid() && !v.IsValid()
   213  	}
   214  
   215  	switch p.Kind() {
   216  	case reflect.Slice:
   217  		if p.Len() != v.Len() {
   218  			return false
   219  		}
   220  		for i := 0; i < p.Len(); i++ {
   221  			if !match(m, p.Index(i), v.Index(i)) {
   222  				return false
   223  			}
   224  		}
   225  		return true
   226  
   227  	case reflect.Struct:
   228  		if p.NumField() != v.NumField() {
   229  			return false
   230  		}
   231  		for i := 0; i < p.NumField(); i++ {
   232  			if !match(m, p.Field(i), v.Field(i)) {
   233  				return false
   234  			}
   235  		}
   236  		return true
   237  
   238  	case reflect.Interface:
   239  		return match(m, p.Elem(), v.Elem())
   240  	}
   241  
   242  	// Handle token integers, etc.
   243  	return p.Interface() == v.Interface()
   244  }
   245  
   246  // subst returns a copy of pattern with values from m substituted in place
   247  // of wildcards and pos used as the position of tokens from the pattern.
   248  // if m == nil, subst returns a copy of pattern and doesn't change the line
   249  // number information.
   250  func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
   251  	if !pattern.IsValid() {
   252  		return reflect.Value{}
   253  	}
   254  
   255  	// Wildcard gets replaced with map value.
   256  	if m != nil && pattern.Type() == identType {
   257  		name := pattern.Interface().(*ast.Ident).Name
   258  		if isWildcard(name) {
   259  			if old, ok := m[name]; ok {
   260  				return subst(nil, old, reflect.Value{})
   261  			}
   262  		}
   263  	}
   264  
   265  	if pos.IsValid() && pattern.Type() == positionType {
   266  		// use new position only if old position was valid in the first place
   267  		if old := pattern.Interface().(token.Pos); !old.IsValid() {
   268  			return pattern
   269  		}
   270  		return pos
   271  	}
   272  
   273  	// Otherwise copy.
   274  	switch p := pattern; p.Kind() {
   275  	case reflect.Slice:
   276  		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
   277  		for i := 0; i < p.Len(); i++ {
   278  			v.Index(i).Set(subst(m, p.Index(i), pos))
   279  		}
   280  		return v
   281  
   282  	case reflect.Struct:
   283  		v := reflect.New(p.Type()).Elem()
   284  		for i := 0; i < p.NumField(); i++ {
   285  			v.Field(i).Set(subst(m, p.Field(i), pos))
   286  		}
   287  		return v
   288  
   289  	case reflect.Ptr:
   290  		v := reflect.New(p.Type()).Elem()
   291  		if elem := p.Elem(); elem.IsValid() {
   292  			v.Set(subst(m, elem, pos).Addr())
   293  		}
   294  		return v
   295  
   296  	case reflect.Interface:
   297  		v := reflect.New(p.Type()).Elem()
   298  		if elem := p.Elem(); elem.IsValid() {
   299  			v.Set(subst(m, elem, pos))
   300  		}
   301  		return v
   302  	}
   303  
   304  	return pattern
   305  }