github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/imports/imports.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:generate go run mkstdlib.go
     6  
     7  // Package imports implements a Go pretty-printer (like package "go/format")
     8  // that also adds or removes import statements as necessary.
     9  package imports
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"fmt"
    15  	"go/ast"
    16  	"go/format"
    17  	"go/parser"
    18  	"go/printer"
    19  	"go/token"
    20  	"io"
    21  	"regexp"
    22  	"strconv"
    23  	"strings"
    24  
    25  	"golang.org/x/tools/go/ast/astutil"
    26  )
    27  
// Options is golang.org/x/tools/imports.Options with extra internal-only options.
//
// Process, FixImports and ApplyFixes in this file dereference their *Options
// argument without a nil check, so a non-nil value is required here.
// NOTE(review): the "if nil *Options provided" defaults noted on the fields
// below are presumably applied by the public golang.org/x/tools/imports
// wrapper, not by this package — confirm there.
type Options struct {
	// Env is the environment handed to fixImports / getFixes when imports
	// are fixed. Note: this contains the cached module and filesystem state.
	Env *ProcessEnv

	// LocalPrefix is a comma-separated string of import path prefixes, which, if
	// set, instructs Process to sort the import paths with the given prefixes
	// into another group after 3rd-party packages.
	LocalPrefix string

	Fragment  bool // Accept a fragment of a source file (declaration or statement list with no package clause)
	AllErrors bool // Report all errors (not just the first 10 on different lines)

	Comments  bool // Parse comments (sets parser.ParseComments), so they are kept in the output (true if nil *Options provided)
	TabIndent bool // Use tabs for indent (true if nil *Options provided)
	TabWidth  int  // Tab width (8 if nil *Options provided)

	FormatOnly bool // Disable the insertion and deletion of imports; Process only reformats
}
    46  
    47  // Process implements golang.org/x/tools/imports.Process with explicit context in opt.Env.
    48  func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
    49  	fileSet := token.NewFileSet()
    50  	file, adjust, err := parse(fileSet, filename, src, opt)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	if !opt.FormatOnly {
    56  		if err := fixImports(fileSet, file, filename, opt.Env); err != nil {
    57  			return nil, err
    58  		}
    59  	}
    60  	return formatFile(fileSet, file, src, adjust, opt)
    61  }
    62  
    63  // FixImports returns a list of fixes to the imports that, when applied,
    64  // will leave the imports in the same state as Process. src and opt must
    65  // be specified.
    66  //
    67  // Note that filename's directory influences which imports can be chosen,
    68  // so it is important that filename be accurate.
    69  func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
    70  	fileSet := token.NewFileSet()
    71  	file, _, err := parse(fileSet, filename, src, opt)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return getFixes(fileSet, file, filename, opt.Env)
    77  }
    78  
    79  // ApplyFixes applies all of the fixes to the file and formats it. extraMode
    80  // is added in when parsing the file. src and opts must be specified, but no
    81  // env is needed.
    82  func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) {
    83  	// Don't use parse() -- we don't care about fragments or statement lists
    84  	// here, and we need to work with unparseable files.
    85  	fileSet := token.NewFileSet()
    86  	parserMode := parser.Mode(0)
    87  	if opt.Comments {
    88  		parserMode |= parser.ParseComments
    89  	}
    90  	if opt.AllErrors {
    91  		parserMode |= parser.AllErrors
    92  	}
    93  	parserMode |= extraMode
    94  
    95  	file, err := parser.ParseFile(fileSet, filename, src, parserMode)
    96  	if file == nil {
    97  		return nil, err
    98  	}
    99  
   100  	// Apply the fixes to the file.
   101  	apply(fileSet, file, fixes)
   102  
   103  	return formatFile(fileSet, file, src, nil, opt)
   104  }
   105  
   106  // formatFile formats the file syntax tree.
   107  // It may mutate the token.FileSet.
   108  //
   109  // If an adjust function is provided, it is called after formatting
   110  // with the original source (formatFile's src parameter) and the
   111  // formatted file, and returns the postpocessed result.
   112  func formatFile(fset *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
   113  	mergeImports(file)
   114  	sortImports(opt.LocalPrefix, fset.File(file.Pos()), file)
   115  	var spacesBefore []string // import paths we need spaces before
   116  	for _, impSection := range astutil.Imports(fset, file) {
   117  		// Within each block of contiguous imports, see if any
   118  		// import lines are in different group numbers. If so,
   119  		// we'll need to put a space between them so it's
   120  		// compatible with gofmt.
   121  		lastGroup := -1
   122  		for _, importSpec := range impSection {
   123  			importPath, _ := strconv.Unquote(importSpec.Path.Value)
   124  			groupNum := importGroup(opt.LocalPrefix, importPath)
   125  			if groupNum != lastGroup && lastGroup != -1 {
   126  				spacesBefore = append(spacesBefore, importPath)
   127  			}
   128  			lastGroup = groupNum
   129  		}
   130  
   131  	}
   132  
   133  	printerMode := printer.UseSpaces
   134  	if opt.TabIndent {
   135  		printerMode |= printer.TabIndent
   136  	}
   137  	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
   138  
   139  	var buf bytes.Buffer
   140  	err := printConfig.Fprint(&buf, fset, file)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	out := buf.Bytes()
   145  	if adjust != nil {
   146  		out = adjust(src, out)
   147  	}
   148  	if len(spacesBefore) > 0 {
   149  		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
   150  		if err != nil {
   151  			return nil, err
   152  		}
   153  	}
   154  
   155  	out, err = format.Source(out)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	return out, nil
   160  }
   161  
   162  // parse parses src, which was read from filename,
   163  // as a Go source file or statement list.
   164  func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
   165  	parserMode := parser.Mode(0)
   166  	if opt.Comments {
   167  		parserMode |= parser.ParseComments
   168  	}
   169  	if opt.AllErrors {
   170  		parserMode |= parser.AllErrors
   171  	}
   172  
   173  	// Try as whole source file.
   174  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   175  	if err == nil {
   176  		return file, nil, nil
   177  	}
   178  	// If the error is that the source file didn't begin with a
   179  	// package line and we accept fragmented input, fall through to
   180  	// try as a source fragment.  Stop and return on any other error.
   181  	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
   182  		return nil, nil, err
   183  	}
   184  
   185  	// If this is a declaration list, make it a source file
   186  	// by inserting a package clause.
   187  	// Insert using a ;, not a newline, so that parse errors are on
   188  	// the correct line.
   189  	const prefix = "package main;"
   190  	psrc := append([]byte(prefix), src...)
   191  	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
   192  	if err == nil {
   193  		// Gofmt will turn the ; into a \n.
   194  		// Do that ourselves now and update the file contents,
   195  		// so that positions and line numbers are correct going forward.
   196  		psrc[len(prefix)-1] = '\n'
   197  		fset.File(file.Package).SetLinesForContent(psrc)
   198  
   199  		// If a main function exists, we will assume this is a main
   200  		// package and leave the file.
   201  		if containsMainFunc(file) {
   202  			return file, nil, nil
   203  		}
   204  
   205  		adjust := func(orig, src []byte) []byte {
   206  			// Remove the package clause.
   207  			src = src[len(prefix):]
   208  			return matchSpace(orig, src)
   209  		}
   210  		return file, adjust, nil
   211  	}
   212  	// If the error is that the source file didn't begin with a
   213  	// declaration, fall through to try as a statement list.
   214  	// Stop and return on any other error.
   215  	if !strings.Contains(err.Error(), "expected declaration") {
   216  		return nil, nil, err
   217  	}
   218  
   219  	// If this is a statement list, make it a source file
   220  	// by inserting a package clause and turning the list
   221  	// into a function body.  This handles expressions too.
   222  	// Insert using a ;, not a newline, so that the line numbers
   223  	// in fsrc match the ones in src.
   224  	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
   225  	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
   226  	if err == nil {
   227  		adjust := func(orig, src []byte) []byte {
   228  			// Remove the wrapping.
   229  			// Gofmt has turned the ; into a \n\n.
   230  			src = src[len("package p\n\nfunc _() {"):]
   231  			src = src[:len(src)-len("}\n")]
   232  			// Gofmt has also indented the function body one level.
   233  			// Remove that indent.
   234  			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
   235  			return matchSpace(orig, src)
   236  		}
   237  		return file, adjust, nil
   238  	}
   239  
   240  	// Failed, and out of options.
   241  	return nil, nil, err
   242  }
   243  
   244  // containsMainFunc checks if a file contains a function declaration with the
   245  // function signature 'func main()'
   246  func containsMainFunc(file *ast.File) bool {
   247  	for _, decl := range file.Decls {
   248  		if f, ok := decl.(*ast.FuncDecl); ok {
   249  			if f.Name.Name != "main" {
   250  				continue
   251  			}
   252  
   253  			if len(f.Type.Params.List) != 0 {
   254  				continue
   255  			}
   256  
   257  			if f.Type.Results != nil && len(f.Type.Results.List) != 0 {
   258  				continue
   259  			}
   260  
   261  			return true
   262  		}
   263  	}
   264  
   265  	return false
   266  }
   267  
   268  func cutSpace(b []byte) (before, middle, after []byte) {
   269  	i := 0
   270  	for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
   271  		i++
   272  	}
   273  	j := len(b)
   274  	for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
   275  		j--
   276  	}
   277  	if i <= j {
   278  		return b[:i], b[i:j], b[j:]
   279  	}
   280  	return nil, nil, b[j:]
   281  }
   282  
   283  // matchSpace reformats src to use the same space context as orig.
   284  //  1. If orig begins with blank lines, matchSpace inserts them at the beginning of src.
   285  //  2. matchSpace copies the indentation of the first non-blank line in orig
   286  //     to every non-blank line in src.
   287  //  3. matchSpace copies the trailing space from orig and uses it in place
   288  //     of src's trailing space.
   289  func matchSpace(orig []byte, src []byte) []byte {
   290  	before, _, after := cutSpace(orig)
   291  	i := bytes.LastIndex(before, []byte{'\n'})
   292  	before, indent := before[:i+1], before[i+1:]
   293  
   294  	_, src, _ = cutSpace(src)
   295  
   296  	var b bytes.Buffer
   297  	b.Write(before)
   298  	for len(src) > 0 {
   299  		line := src
   300  		if i := bytes.IndexByte(line, '\n'); i >= 0 {
   301  			line, src = line[:i+1], line[i+1:]
   302  		} else {
   303  			src = nil
   304  		}
   305  		if len(line) > 0 && line[0] != '\n' { // not blank
   306  			b.Write(indent)
   307  		}
   308  		b.Write(line)
   309  	}
   310  	b.Write(after)
   311  	return b.Bytes()
   312  }
   313  
   314  var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+?)"`)
   315  
   316  func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
   317  	var out bytes.Buffer
   318  	in := bufio.NewReader(r)
   319  	inImports := false
   320  	done := false
   321  	for {
   322  		s, err := in.ReadString('\n')
   323  		if err == io.EOF {
   324  			break
   325  		} else if err != nil {
   326  			return nil, err
   327  		}
   328  
   329  		if !inImports && !done && strings.HasPrefix(s, "import") {
   330  			inImports = true
   331  		}
   332  		if inImports && (strings.HasPrefix(s, "var") ||
   333  			strings.HasPrefix(s, "func") ||
   334  			strings.HasPrefix(s, "const") ||
   335  			strings.HasPrefix(s, "type")) {
   336  			done = true
   337  			inImports = false
   338  		}
   339  		if inImports && len(breaks) > 0 {
   340  			if m := impLine.FindStringSubmatch(s); m != nil {
   341  				if m[1] == breaks[0] {
   342  					out.WriteByte('\n')
   343  					breaks = breaks[1:]
   344  				}
   345  			}
   346  		}
   347  
   348  		fmt.Fprint(&out, s)
   349  	}
   350  	return out.Bytes(), nil
   351  }