github.com/daixiang0/gci@v0.13.0/pkg/gci/gci.go (about)

     1  package gci
     2  
import (
	"bytes"
	"errors"
	"fmt"
	goFormat "go/format"
	"os"
	"runtime"
	"sync"

	"github.com/hexops/gotextdiff"
	"github.com/hexops/gotextdiff/myers"
	"github.com/hexops/gotextdiff/span"
	"golang.org/x/sync/errgroup"

	"github.com/daixiang0/gci/pkg/config"
	"github.com/daixiang0/gci/pkg/format"
	"github.com/daixiang0/gci/pkg/io"
	"github.com/daixiang0/gci/pkg/log"
	"github.com/daixiang0/gci/pkg/parse"
	"github.com/daixiang0/gci/pkg/section"
	"github.com/daixiang0/gci/pkg/utils"
)
    24  
    25  func LocalFlagsToSections(localFlags []string) section.SectionList {
    26  	sections := section.DefaultSections()
    27  	// Add all local arguments as ImportPrefix sections
    28  	// for _, l := range localFlags {
    29  	// 	sections = append(sections, section.Section{l, nil, nil})
    30  	// }
    31  	return sections
    32  }
    33  
    34  func PrintFormattedFiles(paths []string, cfg config.Config) error {
    35  	return processStdInAndGoFilesInPaths(paths, cfg, func(filePath string, unmodifiedFile, formattedFile []byte) error {
    36  		fmt.Print(string(formattedFile))
    37  		return nil
    38  	})
    39  }
    40  
    41  func WriteFormattedFiles(paths []string, cfg config.Config) error {
    42  	return processGoFilesInPaths(paths, cfg, func(filePath string, unmodifiedFile, formattedFile []byte) error {
    43  		if bytes.Equal(unmodifiedFile, formattedFile) {
    44  			log.L().Debug(fmt.Sprintf("Skipping correctly formatted File: %s", filePath))
    45  			return nil
    46  		}
    47  		log.L().Info(fmt.Sprintf("Writing formatted File: %s", filePath))
    48  		return os.WriteFile(filePath, formattedFile, 0o644)
    49  	})
    50  }
    51  
    52  func ListUnFormattedFiles(paths []string, cfg config.Config) error {
    53  	return processGoFilesInPaths(paths, cfg, func(filePath string, unmodifiedFile, formattedFile []byte) error {
    54  		if bytes.Equal(unmodifiedFile, formattedFile) {
    55  			return nil
    56  		}
    57  		fmt.Println(filePath)
    58  		return nil
    59  	})
    60  }
    61  
    62  func DiffFormattedFiles(paths []string, cfg config.Config) error {
    63  	return processStdInAndGoFilesInPaths(paths, cfg, func(filePath string, unmodifiedFile, formattedFile []byte) error {
    64  		fileURI := span.URIFromPath(filePath)
    65  		edits := myers.ComputeEdits(fileURI, string(unmodifiedFile), string(formattedFile))
    66  		unifiedEdits := gotextdiff.ToUnified(filePath, filePath, string(unmodifiedFile), edits)
    67  		fmt.Printf("%v", unifiedEdits)
    68  		return nil
    69  	})
    70  }
    71  
    72  func DiffFormattedFilesToArray(paths []string, cfg config.Config, diffs *[]string, lock *sync.Mutex) error {
    73  	log.InitLogger()
    74  	defer log.L().Sync()
    75  	return processStdInAndGoFilesInPaths(paths, cfg, func(filePath string, unmodifiedFile, formattedFile []byte) error {
    76  		fileURI := span.URIFromPath(filePath)
    77  		edits := myers.ComputeEdits(fileURI, string(unmodifiedFile), string(formattedFile))
    78  		unifiedEdits := gotextdiff.ToUnified(filePath, filePath, string(unmodifiedFile), edits)
    79  		lock.Lock()
    80  		*diffs = append(*diffs, fmt.Sprint(unifiedEdits))
    81  		lock.Unlock()
    82  		return nil
    83  	})
    84  }
    85  
    86  type fileFormattingFunc func(filePath string, unmodifiedFile, formattedFile []byte) error
    87  
    88  func processStdInAndGoFilesInPaths(paths []string, cfg config.Config, fileFunc fileFormattingFunc) error {
    89  	return ProcessFiles(io.StdInGenerator.Combine(io.GoFilesInPathsGenerator(paths, cfg.SkipVendor)), cfg, fileFunc)
    90  }
    91  
    92  func processGoFilesInPaths(paths []string, cfg config.Config, fileFunc fileFormattingFunc) error {
    93  	return ProcessFiles(io.GoFilesInPathsGenerator(paths, cfg.SkipVendor), cfg, fileFunc)
    94  }
    95  
    96  func ProcessFiles(fileGenerator io.FileGeneratorFunc, cfg config.Config, fileFunc fileFormattingFunc) error {
    97  	var taskGroup errgroup.Group
    98  	files, err := fileGenerator()
    99  	if err != nil {
   100  		return err
   101  	}
   102  	for _, file := range files {
   103  		// run file processing in parallel
   104  		taskGroup.Go(processingFunc(file, cfg, fileFunc))
   105  	}
   106  	return taskGroup.Wait()
   107  }
   108  
   109  func processingFunc(file io.FileObj, cfg config.Config, formattingFunc fileFormattingFunc) func() error {
   110  	return func() error {
   111  		unmodifiedFile, formattedFile, err := LoadFormatGoFile(file, cfg)
   112  		if err != nil {
   113  			// if errors.Is(err, FileParsingError{}) {
   114  			// 	// do not process files that are improperly formatted
   115  			// 	return nil
   116  			// }
   117  			return err
   118  		}
   119  		return formattingFunc(file.Path(), unmodifiedFile, formattedFile)
   120  	}
   121  }
   122  
   123  func LoadFormatGoFile(file io.FileObj, cfg config.Config) (src, dist []byte, err error) {
   124  	src, err = file.Load()
   125  	log.L().Debug(fmt.Sprintf("Loaded File: %s", file.Path()))
   126  	if err != nil {
   127  		return nil, nil, err
   128  	}
   129  
   130  	return LoadFormat(src, file.Path(), cfg)
   131  }
   132  
   133  func LoadFormat(in []byte, path string, cfg config.Config) (src, dist []byte, err error) {
   134  	src = in
   135  
   136  	if cfg.SkipGenerated && parse.IsGeneratedFileByComment(string(src)) {
   137  		return src, src, nil
   138  	}
   139  
   140  	imports, headEnd, tailStart, cStart, cEnd, err := parse.ParseFile(src, path)
   141  	if err != nil {
   142  		if errors.Is(err, parse.NoImportError{}) {
   143  			return src, src, nil
   144  		}
   145  		return nil, nil, err
   146  	}
   147  
   148  	// do not do format if only one import
   149  	if len(imports) <= 1 {
   150  		return src, src, nil
   151  	}
   152  
   153  	result, err := format.Format(imports, &cfg)
   154  	if err != nil {
   155  		return nil, nil, err
   156  	}
   157  
   158  	firstWithIndex := true
   159  
   160  	var body []byte
   161  
   162  	// order by section list
   163  	for _, s := range cfg.Sections {
   164  		if len(result[s.String()]) > 0 {
   165  			if len(body) > 0 {
   166  				body = append(body, utils.Linebreak)
   167  			}
   168  			for _, d := range result[s.String()] {
   169  				AddIndent(&body, &firstWithIndex)
   170  				body = append(body, src[d.Start:d.End]...)
   171  			}
   172  		}
   173  	}
   174  
   175  	head := make([]byte, headEnd)
   176  	copy(head, src[:headEnd])
   177  	tail := make([]byte, len(src)-tailStart)
   178  	copy(tail, src[tailStart:])
   179  
   180  	// ensure C
   181  	if cStart != 0 {
   182  		head = append(head, src[cStart:cEnd]...)
   183  		head = append(head, utils.Linebreak)
   184  	}
   185  
   186  	// add beginning of import block
   187  	head = append(head, `import (`...)
   188  	head = append(head, utils.Linebreak)
   189  	// add end of import block
   190  	body = append(body, []byte{utils.RightParenthesis, utils.Linebreak}...)
   191  
   192  	log.L().Debug(fmt.Sprintf("head:\n%s", head))
   193  	log.L().Debug(fmt.Sprintf("body:\n%s", body))
   194  	if len(tail) > 20 {
   195  		log.L().Debug(fmt.Sprintf("tail:\n%s", tail[:20]))
   196  	} else {
   197  		log.L().Debug(fmt.Sprintf("tail:\n%s", tail))
   198  	}
   199  
   200  	var totalLen int
   201  	slices := [][]byte{head, body, tail}
   202  	for _, s := range slices {
   203  		totalLen += len(s)
   204  	}
   205  	dist = make([]byte, totalLen)
   206  	var i int
   207  	for _, s := range slices {
   208  		i += copy(dist[i:], s)
   209  	}
   210  
   211  	// remove ^M(\r\n) from Win to Unix
   212  	dist = bytes.ReplaceAll(dist, []byte{utils.WinLinebreak}, []byte{utils.Linebreak})
   213  
   214  	log.L().Debug(fmt.Sprintf("raw:\n%s", dist))
   215  	dist, err = goFormat.Source(dist)
   216  	if err != nil {
   217  		return nil, nil, err
   218  	}
   219  
   220  	return src, dist, nil
   221  }
   222  
   223  func AddIndent(in *[]byte, first *bool) {
   224  	if *first {
   225  		*first = false
   226  		return
   227  	}
   228  	*in = append(*in, utils.Indent)
   229  }