github.com/riscv/riscv-go@v0.0.0-20200123204226-124ebd6fcc8e/src/cmd/fix/main.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io/ioutil"
    17  	"os"
    18  	"os/exec"
    19  	"path/filepath"
    20  	"sort"
    21  	"strings"
    22  )
    23  
    24  var (
    25  	fset     = token.NewFileSet()
    26  	exitCode = 0
    27  )
    28  
    29  var allowedRewrites = flag.String("r", "",
    30  	"restrict the rewrites to this comma-separated list")
    31  
    32  var forceRewrites = flag.String("force", "",
    33  	"force these fixes to run even if the code looks updated")
    34  
    35  var allowed, force map[string]bool
    36  
    37  var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
    38  
    39  // enable for debugging fix failures
    40  const debug = false // display incorrectly reformatted source and exit
    41  
    42  func usage() {
    43  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    44  	flag.PrintDefaults()
    45  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    46  	sort.Sort(byName(fixes))
    47  	for _, f := range fixes {
    48  		if f.disabled {
    49  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    50  		} else {
    51  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    52  		}
    53  		desc := strings.TrimSpace(f.desc)
    54  		desc = strings.Replace(desc, "\n", "\n\t", -1)
    55  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    56  	}
    57  	os.Exit(2)
    58  }
    59  
    60  func main() {
    61  	flag.Usage = usage
    62  	flag.Parse()
    63  
    64  	sort.Sort(byDate(fixes))
    65  
    66  	if *allowedRewrites != "" {
    67  		allowed = make(map[string]bool)
    68  		for _, f := range strings.Split(*allowedRewrites, ",") {
    69  			allowed[f] = true
    70  		}
    71  	}
    72  
    73  	if *forceRewrites != "" {
    74  		force = make(map[string]bool)
    75  		for _, f := range strings.Split(*forceRewrites, ",") {
    76  			force[f] = true
    77  		}
    78  	}
    79  
    80  	if flag.NArg() == 0 {
    81  		if err := processFile("standard input", true); err != nil {
    82  			report(err)
    83  		}
    84  		os.Exit(exitCode)
    85  	}
    86  
    87  	for i := 0; i < flag.NArg(); i++ {
    88  		path := flag.Arg(i)
    89  		switch dir, err := os.Stat(path); {
    90  		case err != nil:
    91  			report(err)
    92  		case dir.IsDir():
    93  			walkDir(path)
    94  		default:
    95  			if err := processFile(path, false); err != nil {
    96  				report(err)
    97  			}
    98  		}
    99  	}
   100  
   101  	os.Exit(exitCode)
   102  }
   103  
   104  const parserMode = parser.ParseComments
   105  
   106  func gofmtFile(f *ast.File) ([]byte, error) {
   107  	var buf bytes.Buffer
   108  	if err := format.Node(&buf, fset, f); err != nil {
   109  		return nil, err
   110  	}
   111  	return buf.Bytes(), nil
   112  }
   113  
// processFile reads Go source from filename, or from standard input when
// useStdin is set (filename is then used only for positions and messages).
// It applies every eligible fix in the order of the fixes slice. Afterwards
// it either prints a diff (-diff), writes the result to standard output
// (stdin mode), or rewrites the file in place.
//
// A fix is skipped when -r was given and does not name it, or when the fix
// is disabled and -force does not name it. If no fix reports a change,
// nothing is printed or written. Otherwise a summary line naming the applied
// fixes goes to standard error. Read, parse, format, diff and write errors
// are returned to the caller.
func processFile(filename string, useStdin bool) error {
	var f *os.File
	var err error
	var fixlog bytes.Buffer // accumulates " name" for each applied fix

	if useStdin {
		f = os.Stdin
	} else {
		f, err = os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
	}

	src, err := ioutil.ReadAll(f)
	if err != nil {
		return err
	}

	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err != nil {
		return err
	}

	// Apply all fixes to file.
	newFile := file
	fixed := false
	for _, fix := range fixes {
		// -r restricts the set of candidate fixes when it is non-empty.
		if allowed != nil && !allowed[fix.name] {
			continue
		}
		// Disabled fixes run only when explicitly named in -force.
		if fix.disabled && !force[fix.name] {
			continue
		}
		if fix.f(newFile) {
			fixed = true
			fmt.Fprintf(&fixlog, " %s", fix.name)

			// AST changed.
			// Print and parse, to update any missing scoping
			// or position information for subsequent fixers.
			// Note: the := below declares a new err local to this branch.
			newSrc, err := gofmtFile(newFile)
			if err != nil {
				return err
			}
			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
			if err != nil {
				// With debug on, dump the source the fix produced and stop
				// immediately so the failing fix can be inspected.
				if debug {
					fmt.Printf("%s", newSrc)
					report(err)
					os.Exit(exitCode)
				}
				return err
			}
		}
	}
	if !fixed {
		return nil
	}
	// fixlog starts with a space; [1:] drops it. Safe because fixed implies
	// at least one entry was written.
	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])

	// Print AST.  We did that after each fix, so this appears
	// redundant, but it is necessary to generate gofmt-compatible
	// source code in a few cases. The official gofmt style is the
	// output of the printer run on a standard AST generated by the parser,
	// but the source we generated inside the loop above is the
	// output of the printer run on a mangled AST generated by a fixer.
	newSrc, err := gofmtFile(newFile)
	if err != nil {
		return err
	}

	if *doDiff {
		data, err := diff(src, newSrc)
		if err != nil {
			return fmt.Errorf("computing diff: %s", err)
		}
		fmt.Printf("diff %s fixed/%s\n", filename, filename)
		os.Stdout.Write(data)
		return nil
	}

	if useStdin {
		os.Stdout.Write(newSrc)
		return nil
	}

	// f was opened successfully above, so the file normally exists.
	// WriteFile then truncates it, and the 0 permission only applies if
	// the file had to be created.
	return ioutil.WriteFile(f.Name(), newSrc, 0)
}
   204  
   205  var gofmtBuf bytes.Buffer
   206  
   207  func gofmt(n interface{}) string {
   208  	gofmtBuf.Reset()
   209  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   210  		return "<" + err.Error() + ">"
   211  	}
   212  	return gofmtBuf.String()
   213  }
   214  
   215  func report(err error) {
   216  	scanner.PrintError(os.Stderr, err)
   217  	exitCode = 2
   218  }
   219  
   220  func walkDir(path string) {
   221  	filepath.Walk(path, visitFile)
   222  }
   223  
   224  func visitFile(path string, f os.FileInfo, err error) error {
   225  	if err == nil && isGoFile(f) {
   226  		err = processFile(path, false)
   227  	}
   228  	if err != nil {
   229  		report(err)
   230  	}
   231  	return nil
   232  }
   233  
   234  func isGoFile(f os.FileInfo) bool {
   235  	// ignore non-Go files
   236  	name := f.Name()
   237  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   238  }
   239  
   240  func diff(b1, b2 []byte) (data []byte, err error) {
   241  	f1, err := ioutil.TempFile("", "go-fix")
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	defer os.Remove(f1.Name())
   246  	defer f1.Close()
   247  
   248  	f2, err := ioutil.TempFile("", "go-fix")
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	defer os.Remove(f2.Name())
   253  	defer f2.Close()
   254  
   255  	f1.Write(b1)
   256  	f2.Write(b2)
   257  
   258  	data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
   259  	if len(data) > 0 {
   260  		// diff exits with a non-zero status when the files don't match.
   261  		// Ignore that failure as long as we get output.
   262  		err = nil
   263  	}
   264  	return
   265  }