github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/format/format.go (about)

     1  package format
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/build"
     8  	"go/format"
     9  	"go/parser"
    10  	"go/token"
    11  	"os"
    12  	"runtime"
    13  	"strconv"
    14  	"strings"
    15  )
    16  
    17  const (
    18  	Test = 1
    19  )
    20  
    21  type ImportGroups [][]string
    22  
    23  func getAstString(fileSet *token.FileSet, node ast.Node) string {
    24  	buf := &bytes.Buffer{}
    25  	if err := format.Node(buf, fileSet, node); err != nil {
    26  		panic(err)
    27  	}
    28  	return buf.String()
    29  }
    30  
    31  func Format(filename string, src []byte) []byte {
    32  	fileSet := token.NewFileSet()
    33  	file, err := parser.ParseFile(fileSet, filename, src, parser.ParseComments)
    34  	if err != nil {
    35  		panic(fmt.Errorf("errors %s in %s", err.Error(), filename))
    36  	}
    37  	buf := &bytes.Buffer{}
    38  	if err := format.Node(buf, fileSet, file); err != nil {
    39  		panic(fmt.Errorf("errors %s in %s", err.Error(), filename))
    40  	}
    41  	return buf.Bytes()
    42  }
    43  
    44  func Process(filename string, src []byte) ([]byte, error) {
    45  	cwd, _ := os.Getwd()
    46  	fileSet := token.NewFileSet()
    47  	file, err := parser.ParseFile(fileSet, filename, src, parser.ParseComments)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	ast.SortImports(fileSet, file)
    53  
    54  	formattedCode := getAstString(fileSet, file)
    55  
    56  	for _, decl := range file.Decls {
    57  		if genDecl, ok := decl.(*ast.GenDecl); ok {
    58  			if genDecl.Tok != token.IMPORT {
    59  				break
    60  			}
    61  
    62  			importsCode := getAstString(fileSet, genDecl)
    63  
    64  			importGroups := make(ImportGroups, 4)
    65  			for _, spec := range genDecl.Specs {
    66  				importSpec := spec.(*ast.ImportSpec)
    67  				importPath, _ := strconv.Unquote(importSpec.Path.Value)
    68  				pkg, err := build.Import(importPath, "", build.ImportComment)
    69  				if err != nil {
    70  					panic(fmt.Errorf("errors %s in %s", err.Error(), filename))
    71  				}
    72  				if strings.Contains(pkg.Dir, runtime.GOROOT()) {
    73  					// libexec
    74  					importGroups[0] = append(importGroups[0], getAstString(fileSet, importSpec))
    75  				} else {
    76  					if strings.HasPrefix(pkg.Dir, cwd) {
    77  						importGroups[3] = append(importGroups[3], getAstString(fileSet, importSpec))
    78  					} else {
    79  						if strings.HasPrefix(pkg.ImportPath, "git.chinawayltd.com/golib") || strings.HasPrefix(pkg.ImportPath, "g7pay") {
    80  							importGroups[2] = append(importGroups[2], getAstString(fileSet, importSpec))
    81  						} else {
    82  							importGroups[1] = append(importGroups[1], getAstString(fileSet, importSpec))
    83  						}
    84  					}
    85  				}
    86  			}
    87  
    88  			buf := &bytes.Buffer{}
    89  
    90  			buf.WriteString("import (\n")
    91  			for _, importGroup := range importGroups {
    92  				for _, code := range importGroup {
    93  					buf.WriteString(code + "\n")
    94  				}
    95  				buf.WriteString("\n")
    96  			}
    97  			buf.WriteString(")")
    98  			formattedCode = strings.Replace(formattedCode, importsCode, buf.String(), -1)
    99  		}
   100  	}
   101  
   102  	return Format(filename, []byte(formattedCode)), nil
   103  }