github.com/peggyl/go@v0.0.0-20151008231540-ae315999c2d5/src/cmd/fix/main.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io/ioutil"
    17  	"os"
    18  	"os/exec"
    19  	"path/filepath"
    20  	"sort"
    21  	"strings"
    22  )
    23  
    24  var (
    25  	fset     = token.NewFileSet()
    26  	exitCode = 0
    27  )
    28  
    29  var allowedRewrites = flag.String("r", "",
    30  	"restrict the rewrites to this comma-separated list")
    31  
    32  var forceRewrites = flag.String("force", "",
    33  	"force these fixes to run even if the code looks updated")
    34  
    35  var allowed, force map[string]bool
    36  
    37  var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
    38  
    39  // enable for debugging fix failures
    40  const debug = false // display incorrectly reformatted source and exit
    41  
    42  func usage() {
    43  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    44  	flag.PrintDefaults()
    45  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    46  	sort.Sort(byName(fixes))
    47  	for _, f := range fixes {
    48  		fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    49  		desc := strings.TrimSpace(f.desc)
    50  		desc = strings.Replace(desc, "\n", "\n\t", -1)
    51  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    52  	}
    53  	os.Exit(2)
    54  }
    55  
    56  func main() {
    57  	flag.Usage = usage
    58  	flag.Parse()
    59  
    60  	sort.Sort(byDate(fixes))
    61  
    62  	if *allowedRewrites != "" {
    63  		allowed = make(map[string]bool)
    64  		for _, f := range strings.Split(*allowedRewrites, ",") {
    65  			allowed[f] = true
    66  		}
    67  	}
    68  
    69  	if *forceRewrites != "" {
    70  		force = make(map[string]bool)
    71  		for _, f := range strings.Split(*forceRewrites, ",") {
    72  			force[f] = true
    73  		}
    74  	}
    75  
    76  	if flag.NArg() == 0 {
    77  		if err := processFile("standard input", true); err != nil {
    78  			report(err)
    79  		}
    80  		os.Exit(exitCode)
    81  	}
    82  
    83  	for i := 0; i < flag.NArg(); i++ {
    84  		path := flag.Arg(i)
    85  		switch dir, err := os.Stat(path); {
    86  		case err != nil:
    87  			report(err)
    88  		case dir.IsDir():
    89  			walkDir(path)
    90  		default:
    91  			if err := processFile(path, false); err != nil {
    92  				report(err)
    93  			}
    94  		}
    95  	}
    96  
    97  	os.Exit(exitCode)
    98  }
    99  
   100  const parserMode = parser.ParseComments
   101  
   102  func gofmtFile(f *ast.File) ([]byte, error) {
   103  	var buf bytes.Buffer
   104  	if err := format.Node(&buf, fset, f); err != nil {
   105  		return nil, err
   106  	}
   107  	return buf.Bytes(), nil
   108  }
   109  
   110  func processFile(filename string, useStdin bool) error {
   111  	var f *os.File
   112  	var err error
   113  	var fixlog bytes.Buffer
   114  
   115  	if useStdin {
   116  		f = os.Stdin
   117  	} else {
   118  		f, err = os.Open(filename)
   119  		if err != nil {
   120  			return err
   121  		}
   122  		defer f.Close()
   123  	}
   124  
   125  	src, err := ioutil.ReadAll(f)
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	// Apply all fixes to file.
   136  	newFile := file
   137  	fixed := false
   138  	for _, fix := range fixes {
   139  		if allowed != nil && !allowed[fix.name] {
   140  			continue
   141  		}
   142  		if fix.f(newFile) {
   143  			fixed = true
   144  			fmt.Fprintf(&fixlog, " %s", fix.name)
   145  
   146  			// AST changed.
   147  			// Print and parse, to update any missing scoping
   148  			// or position information for subsequent fixers.
   149  			newSrc, err := gofmtFile(newFile)
   150  			if err != nil {
   151  				return err
   152  			}
   153  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   154  			if err != nil {
   155  				if debug {
   156  					fmt.Printf("%s", newSrc)
   157  					report(err)
   158  					os.Exit(exitCode)
   159  				}
   160  				return err
   161  			}
   162  		}
   163  	}
   164  	if !fixed {
   165  		return nil
   166  	}
   167  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   168  
   169  	// Print AST.  We did that after each fix, so this appears
   170  	// redundant, but it is necessary to generate gofmt-compatible
   171  	// source code in a few cases.  The official gofmt style is the
   172  	// output of the printer run on a standard AST generated by the parser,
   173  	// but the source we generated inside the loop above is the
   174  	// output of the printer run on a mangled AST generated by a fixer.
   175  	newSrc, err := gofmtFile(newFile)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	if *doDiff {
   181  		data, err := diff(src, newSrc)
   182  		if err != nil {
   183  			return fmt.Errorf("computing diff: %s", err)
   184  		}
   185  		fmt.Printf("diff %s fixed/%s\n", filename, filename)
   186  		os.Stdout.Write(data)
   187  		return nil
   188  	}
   189  
   190  	if useStdin {
   191  		os.Stdout.Write(newSrc)
   192  		return nil
   193  	}
   194  
   195  	return ioutil.WriteFile(f.Name(), newSrc, 0)
   196  }
   197  
   198  var gofmtBuf bytes.Buffer
   199  
   200  func gofmt(n interface{}) string {
   201  	gofmtBuf.Reset()
   202  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   203  		return "<" + err.Error() + ">"
   204  	}
   205  	return gofmtBuf.String()
   206  }
   207  
   208  func report(err error) {
   209  	scanner.PrintError(os.Stderr, err)
   210  	exitCode = 2
   211  }
   212  
   213  func walkDir(path string) {
   214  	filepath.Walk(path, visitFile)
   215  }
   216  
   217  func visitFile(path string, f os.FileInfo, err error) error {
   218  	if err == nil && isGoFile(f) {
   219  		err = processFile(path, false)
   220  	}
   221  	if err != nil {
   222  		report(err)
   223  	}
   224  	return nil
   225  }
   226  
   227  func isGoFile(f os.FileInfo) bool {
   228  	// ignore non-Go files
   229  	name := f.Name()
   230  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   231  }
   232  
   233  func diff(b1, b2 []byte) (data []byte, err error) {
   234  	f1, err := ioutil.TempFile("", "go-fix")
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	defer os.Remove(f1.Name())
   239  	defer f1.Close()
   240  
   241  	f2, err := ioutil.TempFile("", "go-fix")
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	defer os.Remove(f2.Name())
   246  	defer f2.Close()
   247  
   248  	f1.Write(b1)
   249  	f2.Write(b2)
   250  
   251  	data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
   252  	if len(data) > 0 {
   253  		// diff exits with a non-zero status when the files don't match.
   254  		// Ignore that failure as long as we get output.
   255  		err = nil
   256  	}
   257  	return
   258  }