github.com/ader1990/go@v0.0.0-20140630135419-8c24447fa791/src/cmd/gofmt/gofmt.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/parser"
    13  	"go/printer"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io"
    17  	"io/ioutil"
    18  	"os"
    19  	"os/exec"
    20  	"path/filepath"
    21  	"runtime/pprof"
    22  	"strings"
    23  )
    24  
    25  var (
    26  	// main operation modes
    27  	list        = flag.Bool("l", false, "list files whose formatting differs from gofmt's")
    28  	write       = flag.Bool("w", false, "write result to (source) file instead of stdout")
    29  	rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'a[b:len(a)] -> a[b:]')")
    30  	simplifyAST = flag.Bool("s", false, "simplify code")
    31  	doDiff      = flag.Bool("d", false, "display diffs instead of rewriting files")
    32  	allErrors   = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)")
    33  
    34  	// debugging
    35  	cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file")
    36  )
    37  
    38  const (
    39  	tabWidth    = 8
    40  	printerMode = printer.UseSpaces | printer.TabIndent
    41  )
    42  
    43  var (
    44  	fileSet    = token.NewFileSet() // per process FileSet
    45  	exitCode   = 0
    46  	rewrite    func(*ast.File) *ast.File
    47  	parserMode parser.Mode
    48  )
    49  
    50  func report(err error) {
    51  	scanner.PrintError(os.Stderr, err)
    52  	exitCode = 2
    53  }
    54  
    55  func usage() {
    56  	fmt.Fprintf(os.Stderr, "usage: gofmt [flags] [path ...]\n")
    57  	flag.PrintDefaults()
    58  	os.Exit(2)
    59  }
    60  
    61  func initParserMode() {
    62  	parserMode = parser.ParseComments
    63  	if *allErrors {
    64  		parserMode |= parser.AllErrors
    65  	}
    66  }
    67  
    68  func isGoFile(f os.FileInfo) bool {
    69  	// ignore non-Go files
    70  	name := f.Name()
    71  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
    72  }
    73  
    74  // If in == nil, the source is the contents of the file with the given filename.
    75  func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
    76  	if in == nil {
    77  		f, err := os.Open(filename)
    78  		if err != nil {
    79  			return err
    80  		}
    81  		defer f.Close()
    82  		in = f
    83  	}
    84  
    85  	src, err := ioutil.ReadAll(in)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	file, adjust, err := parse(fileSet, filename, src, stdin)
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	if rewrite != nil {
    96  		if adjust == nil {
    97  			file = rewrite(file)
    98  		} else {
    99  			fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n")
   100  		}
   101  	}
   102  
   103  	ast.SortImports(fileSet, file)
   104  
   105  	if *simplifyAST {
   106  		simplify(file)
   107  	}
   108  
   109  	var buf bytes.Buffer
   110  	err = (&printer.Config{Mode: printerMode, Tabwidth: tabWidth}).Fprint(&buf, fileSet, file)
   111  	if err != nil {
   112  		return err
   113  	}
   114  	res := buf.Bytes()
   115  	if adjust != nil {
   116  		res = adjust(src, res)
   117  	}
   118  
   119  	if !bytes.Equal(src, res) {
   120  		// formatting has changed
   121  		if *list {
   122  			fmt.Fprintln(out, filename)
   123  		}
   124  		if *write {
   125  			err = ioutil.WriteFile(filename, res, 0)
   126  			if err != nil {
   127  				return err
   128  			}
   129  		}
   130  		if *doDiff {
   131  			data, err := diff(src, res)
   132  			if err != nil {
   133  				return fmt.Errorf("computing diff: %s", err)
   134  			}
   135  			fmt.Printf("diff %s gofmt/%s\n", filename, filename)
   136  			out.Write(data)
   137  		}
   138  	}
   139  
   140  	if !*list && !*write && !*doDiff {
   141  		_, err = out.Write(res)
   142  	}
   143  
   144  	return err
   145  }
   146  
   147  func visitFile(path string, f os.FileInfo, err error) error {
   148  	if err == nil && isGoFile(f) {
   149  		err = processFile(path, nil, os.Stdout, false)
   150  	}
   151  	if err != nil {
   152  		report(err)
   153  	}
   154  	return nil
   155  }
   156  
   157  func walkDir(path string) {
   158  	filepath.Walk(path, visitFile)
   159  }
   160  
// main runs gofmtMain and then exits with the status recorded in exitCode.
// The work lives in gofmtMain so that the calls it defers (closing and
// stopping the CPU profile) run before the process exits; os.Exit itself
// does not run deferred calls.
func main() {
	gofmtMain()
	os.Exit(exitCode)
}
   168  
   169  func gofmtMain() {
   170  	flag.Usage = usage
   171  	flag.Parse()
   172  
   173  	if *cpuprofile != "" {
   174  		f, err := os.Create(*cpuprofile)
   175  		if err != nil {
   176  			fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err)
   177  			exitCode = 2
   178  			return
   179  		}
   180  		defer f.Close()
   181  		pprof.StartCPUProfile(f)
   182  		defer pprof.StopCPUProfile()
   183  	}
   184  
   185  	initParserMode()
   186  	initRewrite()
   187  
   188  	if flag.NArg() == 0 {
   189  		if err := processFile("<standard input>", os.Stdin, os.Stdout, true); err != nil {
   190  			report(err)
   191  		}
   192  		return
   193  	}
   194  
   195  	for i := 0; i < flag.NArg(); i++ {
   196  		path := flag.Arg(i)
   197  		switch dir, err := os.Stat(path); {
   198  		case err != nil:
   199  			report(err)
   200  		case dir.IsDir():
   201  			walkDir(path)
   202  		default:
   203  			if err := processFile(path, nil, os.Stdout, false); err != nil {
   204  				report(err)
   205  			}
   206  		}
   207  	}
   208  }
   209  
   210  func diff(b1, b2 []byte) (data []byte, err error) {
   211  	f1, err := ioutil.TempFile("", "gofmt")
   212  	if err != nil {
   213  		return
   214  	}
   215  	defer os.Remove(f1.Name())
   216  	defer f1.Close()
   217  
   218  	f2, err := ioutil.TempFile("", "gofmt")
   219  	if err != nil {
   220  		return
   221  	}
   222  	defer os.Remove(f2.Name())
   223  	defer f2.Close()
   224  
   225  	f1.Write(b1)
   226  	f2.Write(b2)
   227  
   228  	data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
   229  	if len(data) > 0 {
   230  		// diff exits with a non-zero status when the files don't match.
   231  		// Ignore that failure as long as we get output.
   232  		err = nil
   233  	}
   234  	return
   235  
   236  }
   237  
   238  // parse parses src, which was read from filename,
   239  // as a Go source file or statement list.
   240  func parse(fset *token.FileSet, filename string, src []byte, stdin bool) (*ast.File, func(orig, src []byte) []byte, error) {
   241  	// Try as whole source file.
   242  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   243  	if err == nil {
   244  		return file, nil, nil
   245  	}
   246  	// If the error is that the source file didn't begin with a
   247  	// package line and this is standard input, fall through to
   248  	// try as a source fragment.  Stop and return on any other error.
   249  	if !stdin || !strings.Contains(err.Error(), "expected 'package'") {
   250  		return nil, nil, err
   251  	}
   252  
   253  	// If this is a declaration list, make it a source file
   254  	// by inserting a package clause.
   255  	// Insert using a ;, not a newline, so that the line numbers
   256  	// in psrc match the ones in src.
   257  	psrc := append([]byte("package p;"), src...)
   258  	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
   259  	if err == nil {
   260  		adjust := func(orig, src []byte) []byte {
   261  			// Remove the package clause.
   262  			// Gofmt has turned the ; into a \n.
   263  			src = src[len("package p\n"):]
   264  			return matchSpace(orig, src)
   265  		}
   266  		return file, adjust, nil
   267  	}
   268  	// If the error is that the source file didn't begin with a
   269  	// declaration, fall through to try as a statement list.
   270  	// Stop and return on any other error.
   271  	if !strings.Contains(err.Error(), "expected declaration") {
   272  		return nil, nil, err
   273  	}
   274  
   275  	// If this is a statement list, make it a source file
   276  	// by inserting a package clause and turning the list
   277  	// into a function body.  This handles expressions too.
   278  	// Insert using a ;, not a newline, so that the line numbers
   279  	// in fsrc match the ones in src.
   280  	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
   281  	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
   282  	if err == nil {
   283  		adjust := func(orig, src []byte) []byte {
   284  			// Remove the wrapping.
   285  			// Gofmt has turned the ; into a \n\n.
   286  			src = src[len("package p\n\nfunc _() {"):]
   287  			src = src[:len(src)-len("}\n")]
   288  			// Gofmt has also indented the function body one level.
   289  			// Remove that indent.
   290  			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
   291  			return matchSpace(orig, src)
   292  		}
   293  		return file, adjust, nil
   294  	}
   295  
   296  	// Failed, and out of options.
   297  	return nil, nil, err
   298  }
   299  
   300  func cutSpace(b []byte) (before, middle, after []byte) {
   301  	i := 0
   302  	for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
   303  		i++
   304  	}
   305  	j := len(b)
   306  	for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
   307  		j--
   308  	}
   309  	if i <= j {
   310  		return b[:i], b[i:j], b[j:]
   311  	}
   312  	return nil, nil, b[j:]
   313  }
   314  
   315  // matchSpace reformats src to use the same space context as orig.
   316  // 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
   317  // 2) matchSpace copies the indentation of the first non-blank line in orig
   318  //    to every non-blank line in src.
   319  // 3) matchSpace copies the trailing space from orig and uses it in place
   320  //   of src's trailing space.
   321  func matchSpace(orig []byte, src []byte) []byte {
   322  	before, _, after := cutSpace(orig)
   323  	i := bytes.LastIndex(before, []byte{'\n'})
   324  	before, indent := before[:i+1], before[i+1:]
   325  
   326  	_, src, _ = cutSpace(src)
   327  
   328  	var b bytes.Buffer
   329  	b.Write(before)
   330  	for len(src) > 0 {
   331  		line := src
   332  		if i := bytes.IndexByte(line, '\n'); i >= 0 {
   333  			line, src = line[:i+1], line[i+1:]
   334  		} else {
   335  			src = nil
   336  		}
   337  		if len(line) > 0 && line[0] != '\n' { // not blank
   338  			b.Write(indent)
   339  		}
   340  		b.Write(line)
   341  	}
   342  	b.Write(after)
   343  	return b.Bytes()
   344  }