golang.org/x/tools@v0.21.0/internal/imports/imports.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package imports implements a Go pretty-printer (like package "go/format")
     6  // that also adds or removes import statements as necessary.
     7  package imports
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"fmt"
    14  	"go/ast"
    15  	"go/format"
    16  	"go/parser"
    17  	"go/printer"
    18  	"go/token"
    19  	"io"
    20  	"regexp"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"golang.org/x/tools/go/ast/astutil"
    25  	"golang.org/x/tools/internal/event"
    26  )
    27  
// Options is golang.org/x/tools/imports.Options with extra internal-only options.
//
// NOTE(review): Process, FixImports, and ApplyFixes in this file read fields of
// the *Options they receive without a nil check, so callers must pass a
// non-nil value. The "if nil *Options provided" defaults mentioned on the
// fields below appear to describe the public golang.org/x/tools/imports
// wrapper rather than this package — confirm before relying on them here.
type Options struct {
	Env *ProcessEnv // The environment to use. Note: this contains the cached module and filesystem state.

	// LocalPrefix is a comma-separated string of import path prefixes, which, if
	// set, instructs Process to sort the import paths with the given prefixes
	// into another group after 3rd-party packages.
	LocalPrefix string

	// Fragment makes parse retry src as a declaration list and then as a
	// statement list when it lacks a package clause. AllErrors adds
	// parser.AllErrors to the parser mode.
	Fragment  bool // Accept fragment of a source file (no package statement)
	AllErrors bool // Report all errors (not just the first 10 on different lines)

	// Comments adds parser.ParseComments to the parser mode; TabIndent adds
	// printer.TabIndent to the printer mode; TabWidth is passed to
	// printer.Config.Tabwidth.
	Comments  bool // Print comments (true if nil *Options provided)
	TabIndent bool // Use tabs for indent (true if nil *Options provided)
	TabWidth  int  // Tab width (8 if nil *Options provided)

	// FormatOnly makes Process skip fixImports and only reformat.
	FormatOnly bool // Disable the insertion and deletion of imports
}
    46  
    47  // Process implements golang.org/x/tools/imports.Process with explicit context in opt.Env.
    48  func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
    49  	fileSet := token.NewFileSet()
    50  	file, adjust, err := parse(fileSet, filename, src, opt)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	if !opt.FormatOnly {
    56  		if err := fixImports(fileSet, file, filename, opt.Env); err != nil {
    57  			return nil, err
    58  		}
    59  	}
    60  	return formatFile(fileSet, file, src, adjust, opt)
    61  }
    62  
    63  // FixImports returns a list of fixes to the imports that, when applied,
    64  // will leave the imports in the same state as Process. src and opt must
    65  // be specified.
    66  //
    67  // Note that filename's directory influences which imports can be chosen,
    68  // so it is important that filename be accurate.
    69  func FixImports(ctx context.Context, filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
    70  	ctx, done := event.Start(ctx, "imports.FixImports")
    71  	defer done()
    72  
    73  	fileSet := token.NewFileSet()
    74  	file, _, err := parse(fileSet, filename, src, opt)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	return getFixes(ctx, fileSet, file, filename, opt.Env)
    80  }
    81  
    82  // ApplyFixes applies all of the fixes to the file and formats it. extraMode
    83  // is added in when parsing the file. src and opts must be specified, but no
    84  // env is needed.
    85  func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) {
    86  	// Don't use parse() -- we don't care about fragments or statement lists
    87  	// here, and we need to work with unparseable files.
    88  	fileSet := token.NewFileSet()
    89  	parserMode := parser.Mode(0)
    90  	if opt.Comments {
    91  		parserMode |= parser.ParseComments
    92  	}
    93  	if opt.AllErrors {
    94  		parserMode |= parser.AllErrors
    95  	}
    96  	parserMode |= extraMode
    97  
    98  	file, err := parser.ParseFile(fileSet, filename, src, parserMode)
    99  	if file == nil {
   100  		return nil, err
   101  	}
   102  
   103  	// Apply the fixes to the file.
   104  	apply(fileSet, file, fixes)
   105  
   106  	return formatFile(fileSet, file, src, nil, opt)
   107  }
   108  
   109  // formatFile formats the file syntax tree.
   110  // It may mutate the token.FileSet and the ast.File.
   111  //
   112  // If an adjust function is provided, it is called after formatting
   113  // with the original source (formatFile's src parameter) and the
   114  // formatted file, and returns the postpocessed result.
   115  func formatFile(fset *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
   116  	mergeImports(file)
   117  	sortImports(opt.LocalPrefix, fset.File(file.Pos()), file)
   118  	var spacesBefore []string // import paths we need spaces before
   119  	for _, impSection := range astutil.Imports(fset, file) {
   120  		// Within each block of contiguous imports, see if any
   121  		// import lines are in different group numbers. If so,
   122  		// we'll need to put a space between them so it's
   123  		// compatible with gofmt.
   124  		lastGroup := -1
   125  		for _, importSpec := range impSection {
   126  			importPath, _ := strconv.Unquote(importSpec.Path.Value)
   127  			groupNum := importGroup(opt.LocalPrefix, importPath)
   128  			if groupNum != lastGroup && lastGroup != -1 {
   129  				spacesBefore = append(spacesBefore, importPath)
   130  			}
   131  			lastGroup = groupNum
   132  		}
   133  
   134  	}
   135  
   136  	printerMode := printer.UseSpaces
   137  	if opt.TabIndent {
   138  		printerMode |= printer.TabIndent
   139  	}
   140  	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
   141  
   142  	var buf bytes.Buffer
   143  	err := printConfig.Fprint(&buf, fset, file)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	out := buf.Bytes()
   148  	if adjust != nil {
   149  		out = adjust(src, out)
   150  	}
   151  	if len(spacesBefore) > 0 {
   152  		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  	}
   157  
   158  	out, err = format.Source(out)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	return out, nil
   163  }
   164  
   165  // parse parses src, which was read from filename,
   166  // as a Go source file or statement list.
   167  func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
   168  	parserMode := parser.Mode(0)
   169  	if opt.Comments {
   170  		parserMode |= parser.ParseComments
   171  	}
   172  	if opt.AllErrors {
   173  		parserMode |= parser.AllErrors
   174  	}
   175  
   176  	// Try as whole source file.
   177  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   178  	if err == nil {
   179  		return file, nil, nil
   180  	}
   181  	// If the error is that the source file didn't begin with a
   182  	// package line and we accept fragmented input, fall through to
   183  	// try as a source fragment.  Stop and return on any other error.
   184  	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
   185  		return nil, nil, err
   186  	}
   187  
   188  	// If this is a declaration list, make it a source file
   189  	// by inserting a package clause.
   190  	// Insert using a ;, not a newline, so that parse errors are on
   191  	// the correct line.
   192  	const prefix = "package main;"
   193  	psrc := append([]byte(prefix), src...)
   194  	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
   195  	if err == nil {
   196  		// Gofmt will turn the ; into a \n.
   197  		// Do that ourselves now and update the file contents,
   198  		// so that positions and line numbers are correct going forward.
   199  		psrc[len(prefix)-1] = '\n'
   200  		fset.File(file.Package).SetLinesForContent(psrc)
   201  
   202  		// If a main function exists, we will assume this is a main
   203  		// package and leave the file.
   204  		if containsMainFunc(file) {
   205  			return file, nil, nil
   206  		}
   207  
   208  		adjust := func(orig, src []byte) []byte {
   209  			// Remove the package clause.
   210  			src = src[len(prefix):]
   211  			return matchSpace(orig, src)
   212  		}
   213  		return file, adjust, nil
   214  	}
   215  	// If the error is that the source file didn't begin with a
   216  	// declaration, fall through to try as a statement list.
   217  	// Stop and return on any other error.
   218  	if !strings.Contains(err.Error(), "expected declaration") {
   219  		return nil, nil, err
   220  	}
   221  
   222  	// If this is a statement list, make it a source file
   223  	// by inserting a package clause and turning the list
   224  	// into a function body.  This handles expressions too.
   225  	// Insert using a ;, not a newline, so that the line numbers
   226  	// in fsrc match the ones in src.
   227  	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
   228  	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
   229  	if err == nil {
   230  		adjust := func(orig, src []byte) []byte {
   231  			// Remove the wrapping.
   232  			// Gofmt has turned the ; into a \n\n.
   233  			src = src[len("package p\n\nfunc _() {"):]
   234  			src = src[:len(src)-len("}\n")]
   235  			// Gofmt has also indented the function body one level.
   236  			// Remove that indent.
   237  			src = bytes.ReplaceAll(src, []byte("\n\t"), []byte("\n"))
   238  			return matchSpace(orig, src)
   239  		}
   240  		return file, adjust, nil
   241  	}
   242  
   243  	// Failed, and out of options.
   244  	return nil, nil, err
   245  }
   246  
   247  // containsMainFunc checks if a file contains a function declaration with the
   248  // function signature 'func main()'
   249  func containsMainFunc(file *ast.File) bool {
   250  	for _, decl := range file.Decls {
   251  		if f, ok := decl.(*ast.FuncDecl); ok {
   252  			if f.Name.Name != "main" {
   253  				continue
   254  			}
   255  
   256  			if len(f.Type.Params.List) != 0 {
   257  				continue
   258  			}
   259  
   260  			if f.Type.Results != nil && len(f.Type.Results.List) != 0 {
   261  				continue
   262  			}
   263  
   264  			return true
   265  		}
   266  	}
   267  
   268  	return false
   269  }
   270  
   271  func cutSpace(b []byte) (before, middle, after []byte) {
   272  	i := 0
   273  	for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
   274  		i++
   275  	}
   276  	j := len(b)
   277  	for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
   278  		j--
   279  	}
   280  	if i <= j {
   281  		return b[:i], b[i:j], b[j:]
   282  	}
   283  	return nil, nil, b[j:]
   284  }
   285  
   286  // matchSpace reformats src to use the same space context as orig.
   287  //  1. If orig begins with blank lines, matchSpace inserts them at the beginning of src.
   288  //  2. matchSpace copies the indentation of the first non-blank line in orig
   289  //     to every non-blank line in src.
   290  //  3. matchSpace copies the trailing space from orig and uses it in place
   291  //     of src's trailing space.
   292  func matchSpace(orig []byte, src []byte) []byte {
   293  	before, _, after := cutSpace(orig)
   294  	i := bytes.LastIndex(before, []byte{'\n'})
   295  	before, indent := before[:i+1], before[i+1:]
   296  
   297  	_, src, _ = cutSpace(src)
   298  
   299  	var b bytes.Buffer
   300  	b.Write(before)
   301  	for len(src) > 0 {
   302  		line := src
   303  		if i := bytes.IndexByte(line, '\n'); i >= 0 {
   304  			line, src = line[:i+1], line[i+1:]
   305  		} else {
   306  			src = nil
   307  		}
   308  		if len(line) > 0 && line[0] != '\n' { // not blank
   309  			b.Write(indent)
   310  		}
   311  		b.Write(line)
   312  	}
   313  	b.Write(after)
   314  	return b.Bytes()
   315  }
   316  
   317  var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+?)"`)
   318  
   319  func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
   320  	var out bytes.Buffer
   321  	in := bufio.NewReader(r)
   322  	inImports := false
   323  	done := false
   324  	for {
   325  		s, err := in.ReadString('\n')
   326  		if err == io.EOF {
   327  			break
   328  		} else if err != nil {
   329  			return nil, err
   330  		}
   331  
   332  		if !inImports && !done && strings.HasPrefix(s, "import") {
   333  			inImports = true
   334  		}
   335  		if inImports && (strings.HasPrefix(s, "var") ||
   336  			strings.HasPrefix(s, "func") ||
   337  			strings.HasPrefix(s, "const") ||
   338  			strings.HasPrefix(s, "type")) {
   339  			done = true
   340  			inImports = false
   341  		}
   342  		if inImports && len(breaks) > 0 {
   343  			if m := impLine.FindStringSubmatch(s); m != nil {
   344  				if m[1] == breaks[0] {
   345  					out.WriteByte('\n')
   346  					breaks = breaks[1:]
   347  				}
   348  			}
   349  		}
   350  
   351  		fmt.Fprint(&out, s)
   352  	}
   353  	return out.Bytes(), nil
   354  }