github.com/jd-ly/tools@v0.5.7/internal/imports/imports.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:generate go run mkstdlib.go
     6  
     7  // Package imports implements a Go pretty-printer (like package "go/format")
     8  // that also adds or removes import statements as necessary.
     9  package imports
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"fmt"
    15  	"go/ast"
    16  	"go/format"
    17  	"go/parser"
    18  	"go/printer"
    19  	"go/token"
    20  	"io"
    21  	"regexp"
    22  	"strconv"
    23  	"strings"
    24  
    25  	"github.com/jd-ly/tools/go/ast/astutil"
    26  )
    27  
    28  // Options is github.com/jd-ly/tools/imports.Options with extra internal-only options.
    29  type Options struct {
    30  	Env *ProcessEnv // The environment to use. Note: this contains the cached module and filesystem state.
    31  
    32  	// LocalPrefix is a comma-separated string of import path prefixes, which, if
    33  	// set, instructs Process to sort the import paths with the given prefixes
    34  	// into another group after 3rd-party packages.
    35  	LocalPrefix string
    36  
    37  	Fragment  bool // Accept fragment of a source file (no package statement)
    38  	AllErrors bool // Report all errors (not just the first 10 on different lines)
    39  
    40  	Comments  bool // Print comments (true if nil *Options provided)
    41  	TabIndent bool // Use tabs for indent (true if nil *Options provided)
    42  	TabWidth  int  // Tab width (8 if nil *Options provided)
    43  
    44  	FormatOnly bool // Disable the insertion and deletion of imports
    45  }
    46  
    47  // Process implements github.com/jd-ly/tools/imports.Process with explicit context in opt.Env.
    48  func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
    49  	fileSet := token.NewFileSet()
    50  	file, adjust, err := parse(fileSet, filename, src, opt)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	if !opt.FormatOnly {
    56  		if err := fixImports(fileSet, file, filename, opt.Env); err != nil {
    57  			return nil, err
    58  		}
    59  	}
    60  	return formatFile(fileSet, file, src, adjust, opt)
    61  }
    62  
    63  // FixImports returns a list of fixes to the imports that, when applied,
    64  // will leave the imports in the same state as Process. src and opt must
    65  // be specified.
    66  //
    67  // Note that filename's directory influences which imports can be chosen,
    68  // so it is important that filename be accurate.
    69  func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
    70  	fileSet := token.NewFileSet()
    71  	file, _, err := parse(fileSet, filename, src, opt)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return getFixes(fileSet, file, filename, opt.Env)
    77  }
    78  
    79  // ApplyFixes applies all of the fixes to the file and formats it. extraMode
    80  // is added in when parsing the file. src and opts must be specified, but no
    81  // env is needed.
    82  func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) {
    83  	// Don't use parse() -- we don't care about fragments or statement lists
    84  	// here, and we need to work with unparseable files.
    85  	fileSet := token.NewFileSet()
    86  	parserMode := parser.Mode(0)
    87  	if opt.Comments {
    88  		parserMode |= parser.ParseComments
    89  	}
    90  	if opt.AllErrors {
    91  		parserMode |= parser.AllErrors
    92  	}
    93  	parserMode |= extraMode
    94  
    95  	file, err := parser.ParseFile(fileSet, filename, src, parserMode)
    96  	if file == nil {
    97  		return nil, err
    98  	}
    99  
   100  	// Apply the fixes to the file.
   101  	apply(fileSet, file, fixes)
   102  
   103  	return formatFile(fileSet, file, src, nil, opt)
   104  }
   105  
   106  func formatFile(fileSet *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
   107  	mergeImports(fileSet, file)
   108  	sortImports(opt.LocalPrefix, fileSet, file)
   109  	imps := astutil.Imports(fileSet, file)
   110  	var spacesBefore []string // import paths we need spaces before
   111  	for _, impSection := range imps {
   112  		// Within each block of contiguous imports, see if any
   113  		// import lines are in different group numbers. If so,
   114  		// we'll need to put a space between them so it's
   115  		// compatible with gofmt.
   116  		lastGroup := -1
   117  		for _, importSpec := range impSection {
   118  			importPath, _ := strconv.Unquote(importSpec.Path.Value)
   119  			groupNum := importGroup(opt.LocalPrefix, importPath)
   120  			if groupNum != lastGroup && lastGroup != -1 {
   121  				spacesBefore = append(spacesBefore, importPath)
   122  			}
   123  			lastGroup = groupNum
   124  		}
   125  
   126  	}
   127  
   128  	printerMode := printer.UseSpaces
   129  	if opt.TabIndent {
   130  		printerMode |= printer.TabIndent
   131  	}
   132  	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
   133  
   134  	var buf bytes.Buffer
   135  	err := printConfig.Fprint(&buf, fileSet, file)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	out := buf.Bytes()
   140  	if adjust != nil {
   141  		out = adjust(src, out)
   142  	}
   143  	if len(spacesBefore) > 0 {
   144  		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  	}
   149  
   150  	out, err = format.Source(out)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	return out, nil
   155  }
   156  
   157  // parse parses src, which was read from filename,
   158  // as a Go source file or statement list.
   159  func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
   160  	parserMode := parser.Mode(0)
   161  	if opt.Comments {
   162  		parserMode |= parser.ParseComments
   163  	}
   164  	if opt.AllErrors {
   165  		parserMode |= parser.AllErrors
   166  	}
   167  
   168  	// Try as whole source file.
   169  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   170  	if err == nil {
   171  		return file, nil, nil
   172  	}
   173  	// If the error is that the source file didn't begin with a
   174  	// package line and we accept fragmented input, fall through to
   175  	// try as a source fragment.  Stop and return on any other error.
   176  	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
   177  		return nil, nil, err
   178  	}
   179  
   180  	// If this is a declaration list, make it a source file
   181  	// by inserting a package clause.
   182  	// Insert using a ;, not a newline, so that parse errors are on
   183  	// the correct line.
   184  	const prefix = "package main;"
   185  	psrc := append([]byte(prefix), src...)
   186  	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
   187  	if err == nil {
   188  		// Gofmt will turn the ; into a \n.
   189  		// Do that ourselves now and update the file contents,
   190  		// so that positions and line numbers are correct going forward.
   191  		psrc[len(prefix)-1] = '\n'
   192  		fset.File(file.Package).SetLinesForContent(psrc)
   193  
   194  		// If a main function exists, we will assume this is a main
   195  		// package and leave the file.
   196  		if containsMainFunc(file) {
   197  			return file, nil, nil
   198  		}
   199  
   200  		adjust := func(orig, src []byte) []byte {
   201  			// Remove the package clause.
   202  			src = src[len(prefix):]
   203  			return matchSpace(orig, src)
   204  		}
   205  		return file, adjust, nil
   206  	}
   207  	// If the error is that the source file didn't begin with a
   208  	// declaration, fall through to try as a statement list.
   209  	// Stop and return on any other error.
   210  	if !strings.Contains(err.Error(), "expected declaration") {
   211  		return nil, nil, err
   212  	}
   213  
   214  	// If this is a statement list, make it a source file
   215  	// by inserting a package clause and turning the list
   216  	// into a function body.  This handles expressions too.
   217  	// Insert using a ;, not a newline, so that the line numbers
   218  	// in fsrc match the ones in src.
   219  	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
   220  	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
   221  	if err == nil {
   222  		adjust := func(orig, src []byte) []byte {
   223  			// Remove the wrapping.
   224  			// Gofmt has turned the ; into a \n\n.
   225  			src = src[len("package p\n\nfunc _() {"):]
   226  			src = src[:len(src)-len("}\n")]
   227  			// Gofmt has also indented the function body one level.
   228  			// Remove that indent.
   229  			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
   230  			return matchSpace(orig, src)
   231  		}
   232  		return file, adjust, nil
   233  	}
   234  
   235  	// Failed, and out of options.
   236  	return nil, nil, err
   237  }
   238  
   239  // containsMainFunc checks if a file contains a function declaration with the
   240  // function signature 'func main()'
   241  func containsMainFunc(file *ast.File) bool {
   242  	for _, decl := range file.Decls {
   243  		if f, ok := decl.(*ast.FuncDecl); ok {
   244  			if f.Name.Name != "main" {
   245  				continue
   246  			}
   247  
   248  			if len(f.Type.Params.List) != 0 {
   249  				continue
   250  			}
   251  
   252  			if f.Type.Results != nil && len(f.Type.Results.List) != 0 {
   253  				continue
   254  			}
   255  
   256  			return true
   257  		}
   258  	}
   259  
   260  	return false
   261  }
   262  
   263  func cutSpace(b []byte) (before, middle, after []byte) {
   264  	i := 0
   265  	for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
   266  		i++
   267  	}
   268  	j := len(b)
   269  	for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
   270  		j--
   271  	}
   272  	if i <= j {
   273  		return b[:i], b[i:j], b[j:]
   274  	}
   275  	return nil, nil, b[j:]
   276  }
   277  
   278  // matchSpace reformats src to use the same space context as orig.
   279  // 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
   280  // 2) matchSpace copies the indentation of the first non-blank line in orig
   281  //    to every non-blank line in src.
   282  // 3) matchSpace copies the trailing space from orig and uses it in place
   283  //   of src's trailing space.
   284  func matchSpace(orig []byte, src []byte) []byte {
   285  	before, _, after := cutSpace(orig)
   286  	i := bytes.LastIndex(before, []byte{'\n'})
   287  	before, indent := before[:i+1], before[i+1:]
   288  
   289  	_, src, _ = cutSpace(src)
   290  
   291  	var b bytes.Buffer
   292  	b.Write(before)
   293  	for len(src) > 0 {
   294  		line := src
   295  		if i := bytes.IndexByte(line, '\n'); i >= 0 {
   296  			line, src = line[:i+1], line[i+1:]
   297  		} else {
   298  			src = nil
   299  		}
   300  		if len(line) > 0 && line[0] != '\n' { // not blank
   301  			b.Write(indent)
   302  		}
   303  		b.Write(line)
   304  	}
   305  	b.Write(after)
   306  	return b.Bytes()
   307  }
   308  
   309  var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+)"`)
   310  
   311  func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
   312  	var out bytes.Buffer
   313  	in := bufio.NewReader(r)
   314  	inImports := false
   315  	done := false
   316  	for {
   317  		s, err := in.ReadString('\n')
   318  		if err == io.EOF {
   319  			break
   320  		} else if err != nil {
   321  			return nil, err
   322  		}
   323  
   324  		if !inImports && !done && strings.HasPrefix(s, "import") {
   325  			inImports = true
   326  		}
   327  		if inImports && (strings.HasPrefix(s, "var") ||
   328  			strings.HasPrefix(s, "func") ||
   329  			strings.HasPrefix(s, "const") ||
   330  			strings.HasPrefix(s, "type")) {
   331  			done = true
   332  			inImports = false
   333  		}
   334  		if inImports && len(breaks) > 0 {
   335  			if m := impLine.FindStringSubmatch(s); m != nil {
   336  				if m[1] == breaks[0] {
   337  					out.WriteByte('\n')
   338  					breaks = breaks[1:]
   339  				}
   340  			}
   341  		}
   342  
   343  		fmt.Fprint(&out, s)
   344  	}
   345  	return out.Bytes(), nil
   346  }