github.com/goproxy0/go@v0.0.0-20171111080102-49cc0c489d2c/src/cmd/fix/main.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io/ioutil"
    17  	"os"
    18  	"os/exec"
    19  	"path/filepath"
    20  	"runtime"
    21  	"sort"
    22  	"strings"
    23  )
    24  
    25  var (
    26  	fset     = token.NewFileSet()
    27  	exitCode = 0
    28  )
    29  
    30  var allowedRewrites = flag.String("r", "",
    31  	"restrict the rewrites to this comma-separated list")
    32  
    33  var forceRewrites = flag.String("force", "",
    34  	"force these fixes to run even if the code looks updated")
    35  
    36  var allowed, force map[string]bool
    37  
    38  var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
    39  
    40  // enable for debugging fix failures
    41  const debug = false // display incorrectly reformatted source and exit
    42  
    43  func usage() {
    44  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    45  	flag.PrintDefaults()
    46  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    47  	sort.Sort(byName(fixes))
    48  	for _, f := range fixes {
    49  		if f.disabled {
    50  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    51  		} else {
    52  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    53  		}
    54  		desc := strings.TrimSpace(f.desc)
    55  		desc = strings.Replace(desc, "\n", "\n\t", -1)
    56  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    57  	}
    58  	os.Exit(2)
    59  }
    60  
    61  func main() {
    62  	flag.Usage = usage
    63  	flag.Parse()
    64  
    65  	sort.Sort(byDate(fixes))
    66  
    67  	if *allowedRewrites != "" {
    68  		allowed = make(map[string]bool)
    69  		for _, f := range strings.Split(*allowedRewrites, ",") {
    70  			allowed[f] = true
    71  		}
    72  	}
    73  
    74  	if *forceRewrites != "" {
    75  		force = make(map[string]bool)
    76  		for _, f := range strings.Split(*forceRewrites, ",") {
    77  			force[f] = true
    78  		}
    79  	}
    80  
    81  	if flag.NArg() == 0 {
    82  		if err := processFile("standard input", true); err != nil {
    83  			report(err)
    84  		}
    85  		os.Exit(exitCode)
    86  	}
    87  
    88  	for i := 0; i < flag.NArg(); i++ {
    89  		path := flag.Arg(i)
    90  		switch dir, err := os.Stat(path); {
    91  		case err != nil:
    92  			report(err)
    93  		case dir.IsDir():
    94  			walkDir(path)
    95  		default:
    96  			if err := processFile(path, false); err != nil {
    97  				report(err)
    98  			}
    99  		}
   100  	}
   101  
   102  	os.Exit(exitCode)
   103  }
   104  
   105  const parserMode = parser.ParseComments
   106  
   107  func gofmtFile(f *ast.File) ([]byte, error) {
   108  	var buf bytes.Buffer
   109  	if err := format.Node(&buf, fset, f); err != nil {
   110  		return nil, err
   111  	}
   112  	return buf.Bytes(), nil
   113  }
   114  
   115  func processFile(filename string, useStdin bool) error {
   116  	var f *os.File
   117  	var err error
   118  	var fixlog bytes.Buffer
   119  
   120  	if useStdin {
   121  		f = os.Stdin
   122  	} else {
   123  		f, err = os.Open(filename)
   124  		if err != nil {
   125  			return err
   126  		}
   127  		defer f.Close()
   128  	}
   129  
   130  	src, err := ioutil.ReadAll(f)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	// Apply all fixes to file.
   141  	newFile := file
   142  	fixed := false
   143  	for _, fix := range fixes {
   144  		if allowed != nil && !allowed[fix.name] {
   145  			continue
   146  		}
   147  		if fix.disabled && !force[fix.name] {
   148  			continue
   149  		}
   150  		if fix.f(newFile) {
   151  			fixed = true
   152  			fmt.Fprintf(&fixlog, " %s", fix.name)
   153  
   154  			// AST changed.
   155  			// Print and parse, to update any missing scoping
   156  			// or position information for subsequent fixers.
   157  			newSrc, err := gofmtFile(newFile)
   158  			if err != nil {
   159  				return err
   160  			}
   161  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   162  			if err != nil {
   163  				if debug {
   164  					fmt.Printf("%s", newSrc)
   165  					report(err)
   166  					os.Exit(exitCode)
   167  				}
   168  				return err
   169  			}
   170  		}
   171  	}
   172  	if !fixed {
   173  		return nil
   174  	}
   175  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   176  
   177  	// Print AST.  We did that after each fix, so this appears
   178  	// redundant, but it is necessary to generate gofmt-compatible
   179  	// source code in a few cases. The official gofmt style is the
   180  	// output of the printer run on a standard AST generated by the parser,
   181  	// but the source we generated inside the loop above is the
   182  	// output of the printer run on a mangled AST generated by a fixer.
   183  	newSrc, err := gofmtFile(newFile)
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	if *doDiff {
   189  		data, err := diff(src, newSrc)
   190  		if err != nil {
   191  			return fmt.Errorf("computing diff: %s", err)
   192  		}
   193  		fmt.Printf("diff %s fixed/%s\n", filename, filename)
   194  		os.Stdout.Write(data)
   195  		return nil
   196  	}
   197  
   198  	if useStdin {
   199  		os.Stdout.Write(newSrc)
   200  		return nil
   201  	}
   202  
   203  	return ioutil.WriteFile(f.Name(), newSrc, 0)
   204  }
   205  
   206  var gofmtBuf bytes.Buffer
   207  
   208  func gofmt(n interface{}) string {
   209  	gofmtBuf.Reset()
   210  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   211  		return "<" + err.Error() + ">"
   212  	}
   213  	return gofmtBuf.String()
   214  }
   215  
   216  func report(err error) {
   217  	scanner.PrintError(os.Stderr, err)
   218  	exitCode = 2
   219  }
   220  
   221  func walkDir(path string) {
   222  	filepath.Walk(path, visitFile)
   223  }
   224  
   225  func visitFile(path string, f os.FileInfo, err error) error {
   226  	if err == nil && isGoFile(f) {
   227  		err = processFile(path, false)
   228  	}
   229  	if err != nil {
   230  		report(err)
   231  	}
   232  	return nil
   233  }
   234  
   235  func isGoFile(f os.FileInfo) bool {
   236  	// ignore non-Go files
   237  	name := f.Name()
   238  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   239  }
   240  
   241  func writeTempFile(dir, prefix string, data []byte) (string, error) {
   242  	file, err := ioutil.TempFile(dir, prefix)
   243  	if err != nil {
   244  		return "", err
   245  	}
   246  	_, err = file.Write(data)
   247  	if err1 := file.Close(); err == nil {
   248  		err = err1
   249  	}
   250  	if err != nil {
   251  		os.Remove(file.Name())
   252  		return "", err
   253  	}
   254  	return file.Name(), nil
   255  }
   256  
   257  func diff(b1, b2 []byte) (data []byte, err error) {
   258  	f1, err := writeTempFile("", "go-fix", b1)
   259  	if err != nil {
   260  		return
   261  	}
   262  	defer os.Remove(f1)
   263  
   264  	f2, err := writeTempFile("", "go-fix", b2)
   265  	if err != nil {
   266  		return
   267  	}
   268  	defer os.Remove(f2)
   269  
   270  	cmd := "diff"
   271  	if runtime.GOOS == "plan9" {
   272  		cmd = "/bin/ape/diff"
   273  	}
   274  
   275  	data, err = exec.Command(cmd, "-u", f1, f2).CombinedOutput()
   276  	if len(data) > 0 {
   277  		// diff exits with a non-zero status when the files don't match.
   278  		// Ignore that failure as long as we get output.
   279  		err = nil
   280  	}
   281  	return
   282  }